未验证 提交 7e049108 编写于 作者: L Leo Chen 提交者: GitHub

[feature] support npu operator (#30951)

[feature] support npu operator
上级 81138239
......@@ -11,12 +11,16 @@ function(op_library TARGET)
set(miopen_hip_cc_srcs)
set(cu_cc_srcs)
set(xpu_cc_srcs)
set(npu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(cudnn_cu_srcs)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
if (WITH_ASCEND_CL)
set(op_common_deps ${op_common_deps} npu_op_runner)
endif()
# Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build.
set(options UNITY)
set(oneValueArgs "")
......@@ -84,6 +88,12 @@ function(op_library TARGET)
list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
endif()
endif()
if(WITH_ASCEND_CL)
string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
endif()
endif()
else()
foreach(src ${op_library_SRCS})
if (WITH_ROCM_PLATFORM AND ${src} MATCHES ".*\\.hip.cu$")
......@@ -106,6 +116,8 @@ function(op_library TARGET)
list(APPEND cu_cc_srcs ${src})
elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$")
list(APPEND xpu_cc_srcs ${src})
elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
list(APPEND npu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
......@@ -170,7 +182,7 @@ function(op_library TARGET)
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
if(WITH_UNITY_BUILD AND op_library_UNITY)
# Combine the cc source files.
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs})
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs})
if(TARGET ${UNITY_TARGET})
# If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
......@@ -181,7 +193,7 @@ function(op_library TARGET)
# Add alias library to handle dependencies.
add_library(${TARGET} ALIAS ${UNITY_TARGET})
else()
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS}
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
endif()
endif()
......@@ -230,10 +242,11 @@ function(op_library TARGET)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
list(LENGTH npu_cc_srcs npu_cc_srcs_len)
list(LENGTH hip_cu_srcs hip_cu_srcs_len)
list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0)
${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
......@@ -273,6 +286,9 @@ function(op_library TARGET)
if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
endif()
if (WITH_XPU AND ${npu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, NPU);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
......@@ -323,6 +339,7 @@ function(register_operators)
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
string(REPLACE "_mkldnn" "" OPS "${OPS}")
string(REPLACE "_xpu" "" OPS "${OPS}")
string(REPLACE "_npu" "" OPS "${OPS}")
string(REPLACE ".cc" "" OPS "${OPS}")
list(REMOVE_DUPLICATES OPS)
list(LENGTH register_operators_DEPS register_operators_DEPS_len)
......
......@@ -61,6 +61,8 @@ inline LibraryType StringToLibraryType(const char* ctype) {
return LibraryType::kPlain;
} else if (s == std::string("XPU")) {
return LibraryType::kPlain;
} else if (s == std::string("NPU")) {
return LibraryType::kPlain;
} else if (s == std::string("CUDA")) {
return LibraryType::kPlain;
} else {
......
......@@ -304,6 +304,9 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define REGISTER_OP_XPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, XPU, ::paddle::platform::XPUPlace, __VA_ARGS__)
#define REGISTER_OP_NPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, NPU, ::paddle::platform::NPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
......
......@@ -212,6 +212,16 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
#else
auto dev_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_npu_place(place)) {
#ifndef PADDLE_WITH_ASCEND_CL
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with NPU support.",
place));
#else
auto dev_id = BOOST_GET_CONST(platform::NPUPlace, place).device;
platform::SetNPUDeviceId(dev_id);
#endif
}
......
......@@ -125,25 +125,54 @@ TEST(Tensor, MutableData) {
float* p2 = nullptr;
// initialization
p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
platform::CUDAPlace());
platform::CUDAPlace(0));
auto p1_holder = src_tensor.Holder();
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 1024}),
platform::CUDAPlace());
platform::CUDAPlace(0));
auto p2_holder = src_tensor.Holder();
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1_holder.get(), p2_holder.get());
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
platform::CUDAPlace());
platform::CUDAPlace(0));
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
platform::CUDAPlace());
platform::CUDAPlace(0));
EXPECT_EQ(p1, p2);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
{
framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
platform::NPUPlace(0));
auto p1_holder = src_tensor.Holder();
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 1024}),
platform::NPUPlace(0));
auto p2_holder = src_tensor.Holder();
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1_holder.get(), p2_holder.get());
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
platform::NPUPlace(0));
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
platform::NPUPlace(0));
EXPECT_EQ(p1, p2);
}
#endif
......@@ -179,7 +208,17 @@ TEST(Tensor, ShareDataWith) {
framework::Tensor src_tensor;
framework::Tensor dst_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
platform::CUDAPlace());
platform::CUDAPlace(0));
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
{
framework::Tensor src_tensor;
framework::Tensor dst_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
platform::NPUPlace(0));
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
......@@ -216,7 +255,34 @@ TEST(Tensor, Slice) {
{
framework::Tensor src_tensor;
src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
platform::CUDAPlace());
platform::CUDAPlace(0));
framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address =
reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
src_tensor.dims(), platform::CUDAPlace(0)));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
slice_tensor.dims(), platform::CUDAPlace(0)));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
{
framework::Tensor src_tensor;
src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
platform::NPUPlace(0));
framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
......@@ -227,12 +293,12 @@ TEST(Tensor, Slice) {
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address =
reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
src_tensor.dims(), platform::CUDAPlace()));
src_tensor.dims(), platform::NPUPlace(0)));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
slice_tensor.dims(), platform::CUDAPlace()));
slice_tensor.dims(), platform::NPUPlace(0)));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
......
......@@ -158,6 +158,14 @@ void TensorFromVector(const std::vector<T>& src,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(dst_place)) { // NOLINT
memory::Copy(
BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place,
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
}
#endif
}
template <typename T>
......@@ -195,6 +203,14 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()),
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
}
#endif
}
template <typename T>
......
......@@ -198,6 +198,85 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num,
aclrtStream stream) {
if (UNLIKELY(num == 0)) return;
platform::SetNPUDeviceId(dst_place.device);
VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
} else {
platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
}
}
template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::NPUPlace src_place,
const void* src, size_t num,
aclrtStream stream) {
if (UNLIKELY(num == 0)) return;
platform::SetNPUDeviceId(src_place.device);
VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
} else {
platform::RecordEvent record_event("GpuMemcpySync:NPU->CPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
}
}
template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
void* dst,
platform::NPUPlace src_place,
const void* src, size_t num,
aclrtStream stream) {
if (UNLIKELY(num == 0)) return;
VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << " by stream(" << stream << ")";
if (dst_place == src_place) {
platform::SetNPUDeviceId(src_place.device);
if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
stream);
} else {
platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
}
} else {
if (!platform::NPUCanAccessPeer(dst_place.device, dst_place.device)) {
PADDLE_THROW(platform::errors::Unavailable(
"Peer access between NPU places is not allowed."));
}
if (stream) {
// TODO(zhiqiu): support peer access?
platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
stream);
} else {
platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
}
}
}
#endif
#ifdef PADDLE_WITH_CUDA
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024; // 64K
......
......@@ -52,6 +52,26 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif
#ifdef PADDLE_WITH_ASCEND_CL
/**
* \brief Copy memory from one place to another place.
*
* \param[in] DstPlace Destination allocation place (CPU or NPU).
* \param[in] dst Destination memory address.
* \param[in] SrcPlace Source allocation place (CPU or NPU).
* \param[in] src Source memory address.
* \param[in] num memory size in bytes to copy.
* \param[in] stream NPU stream.
*
* \note For NPU memory copy, NPU stream need to be specified
* for asynchronously memory copy.
*
*/
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
aclrtStream stream);
#endif
} // namespace memory
......
......@@ -119,6 +119,11 @@ if (WITH_ASCEND)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} ascend_wrapper)
endif()
if (WITH_ASCEND_CL)
cc_library(npu_op_runner SRCS npu_op_runner.cc DEPS operator npu_info)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner)
endif()
# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
# op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
......
......@@ -8,3 +8,4 @@ register_operators(DEPS op_version_registry)
cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(elementwise_add_op_npu_test SRCS elementwise_add_op_npu_test.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
// TODO(zhiqiu): get the attr infomation of Ascend op and
// convert paddle AttributeMap to Ascend attrs.
// Ascend op add has no attribute ?
// int axis = ctx.Attr<int>("axis");
// NOTE(zhiqiu): the order of inputs and outputs is important
auto runner = NpuOpRunner("Add", {*x, *y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
elementwise_add,
ops::ElementwiseAddNPUKernel<paddle::platform::NPUDeviceContext, float>);
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto tensor_x = x->GetMutable<f::LoDTensor>();
tensor_x->Resize({10, 10});
auto y = scope->Var("Y");
auto tensor_y = y->GetMutable<f::LoDTensor>();
tensor_y->Resize({10, 10});
std::vector<float> init;
for (int64_t i = 0; i < 10 * 10; ++i) {
init.push_back(1.0);
}
TensorFromVector(init, ctx, tensor_x);
TensorFromVector(init, ctx, tensor_y);
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({10, 10});
tensor_out->mutable_data<float>(place); // allocate
// run
f::AttributeMap attrs;
auto op =
f::OpRegistry::CreateOp("elementwise_add", {{"X", {"X"}}, {"Y", {"Y"}}},
{{"Out", {"Out"}}}, attrs);
op->Run(*scope, place);
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
EXPECT_EQ(out_vec.size(), init.size());
for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], 2.0);
}
}
TEST(elementwise_add, NPU) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
Compare(&scope, ctx);
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/npu_op_runner.h"
#include <paddle/fluid/framework/operator.h>
#include <paddle/fluid/framework/data_type.h>
#include <map>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace operators {
static std::map<framework::proto::VarType::Type, aclDataType> DTYPE_2_ACL_DTYPE = {
{framework::proto::VarType::BOOL, ACL_BOOL}, {framework::proto::VarType::INT16, ACL_INT16},
{framework::proto::VarType::INT32, ACL_INT32}, {framework::proto::VarType::INT64, ACL_INT64},
{framework::proto::VarType::FP16, ACL_FLOAT16}, {framework::proto::VarType::FP32, ACL_FLOAT},
{framework::proto::VarType::FP64, ACL_DOUBLE},
};
static std::map<DataLayout, aclFormat> DATA_LAYOUT_2_ACL_FORMAT = {
{DataLayout::kNCHW, ACL_FORMAT_NCHW},
{DataLayout::kNHWC, ACL_FORMAT_NHWC},
{DataLayout::kAnyLayout, ACL_FORMAT_ND},
};
aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype) {
auto iter = DTYPE_2_ACL_DTYPE.find(dtype);
PADDLE_ENFORCE_NE(iter, DTYPE_2_ACL_DTYPE.end(),
platform::errors::NotFound(
"The data type (%s) can not convert to ACL data type.",
framework::DataTypeToString(dtype)));
return iter->second;
}
aclFormat ConvertToNpuFormat(DataLayout layout) {
auto iter = DATA_LAYOUT_2_ACL_FORMAT.find(layout);
PADDLE_ENFORCE_NE(
iter, DATA_LAYOUT_2_ACL_FORMAT.end(),
platform::errors::NotFound(
"The data type (%s) can not convert to ACL data type.", layout));
return iter->second;
}
NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {}
NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs,
const std::vector<Tensor> &outputs,
const AttributeMap &attrs)
: op_type_(op_type) {
AddInputs(inputs);
AddOutputs(outputs);
AddAttrs(attrs);
}
NpuOpRunner::~NpuOpRunner() {
//TODO(zhiqiu): handle free
}
const std::string &NpuOpRunner::Type() { return op_type_; }
NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
const Attribute &attr) {
if (attr.type() == typeid(bool)) {
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr)));
} else if (attr.type() == typeid(int)) {
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrInt(attr_, name.c_str(), BOOST_GET_CONST(int, attr)));
} else if (attr.type() == typeid(int64_t)) {
PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrInt(
attr_, name.c_str(), BOOST_GET_CONST(int64_t, attr)));
} else if (attr.type() == typeid(float)) {
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrFloat(attr_, name.c_str(), BOOST_GET_CONST(float, attr)));
} else if (attr.type() == typeid(std::vector<bool>)) {
auto a = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<uint8_t> cast_a;
for(auto it : a) {
cast_a.push_back(static_cast<uint8_t>(it));
}
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListBool(attr_, name.c_str(), cast_a.size(), cast_a.data()));
} else if (attr.type() == typeid(std::vector<int>)) {
auto a = BOOST_GET_CONST(std::vector<int>, attr);
std::vector<int64_t> cast_a;
for(auto it : a) {
cast_a.push_back(static_cast<int64_t>(it));
}
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListInt(attr_, name.c_str(), cast_a.size(), cast_a.data()));
} else if (attr.type() == typeid(std::vector<int64_t>)) {
auto a = BOOST_GET_CONST(std::vector<int64_t>, attr);
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListInt(attr_, name.c_str(), a.size(), a.data()));
} else if (attr.type() == typeid(std::vector<float>)) {
auto a = BOOST_GET_CONST(std::vector<float>, attr);
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListFloat(attr_, name.c_str(), a.size(), a.data()));
} else if (attr.type() == typeid(std::string)) {
auto a = BOOST_GET_CONST(std::string, attr);
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrString(attr_, name.c_str(), a.c_str()));
} else if (attr.type() == typeid(std::vector<std::string>)) {
auto a = BOOST_GET_CONST(std::vector<std::string>, attr);
std::vector<const char *> s;
for (auto &it : a) {
s.push_back(it.data());
}
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data()));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Can not convert attribubte '%s' to convert to aclopAttr", name));
}
return *this;
}
NpuOpRunner &NpuOpRunner::AddAttrs(const AttributeMap &attrs) {
for (const auto &pair : attrs) {
AddAttr(pair.first, pair.second);
}
return *this;
}
NpuOpRunner &NpuOpRunner::AddInput(const Tensor &tensor) {
// create aclTensorDesc
input_descs_.emplace_back(CreateTensorDesc(tensor));
// create aclDataBuffer
input_buffers_.emplace_back(CreateDataBuffer(tensor));
return *this;
}
NpuOpRunner &NpuOpRunner::AddOutput(const Tensor &tensor) {
// create aclTensorDesc
output_descs_.emplace_back(CreateTensorDesc(tensor));
// create aclDataBuffer
output_buffers_.emplace_back(CreateDataBuffer(tensor));
return *this;
}
NpuOpRunner &NpuOpRunner::AddInputs(const std::vector<Tensor> &tensors) {
for (auto tensor : tensors) {
// create aclTensorDesc
input_descs_.emplace_back(CreateTensorDesc(tensor));
// create aclDataBuffer
input_buffers_.emplace_back(CreateDataBuffer(tensor));
}
return *this;
}
NpuOpRunner &NpuOpRunner::AddOutputs(const std::vector<Tensor> &tensors) {
for (auto tensor : tensors) {
// create aclTensorDesc
output_descs_.emplace_back(CreateTensorDesc(tensor));
// create aclDataBuffer
output_buffers_.emplace_back(CreateDataBuffer(tensor));
}
return *this;
}
aclTensorDesc *NpuOpRunner::GetInputDesc(size_t index) {
PADDLE_ENFORCE_LT(index, input_descs_.size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
Type(), index, input_descs_.size()));
return input_descs_[index];
}
aclTensorDesc *NpuOpRunner::GetOutputDesc(size_t index) {
PADDLE_ENFORCE_LT(index, output_descs_.size(),
platform::errors::OutOfRange(
"The index should be less than the size of output of "
"operator %s, but got index is %d and size is %d",
Type(), index, output_descs_.size()));
return output_descs_[index];
}
std::vector<aclTensorDesc *> &NpuOpRunner::GetInputDescs() {
return input_descs_;
}
std::vector<aclTensorDesc *> &NpuOpRunner::GetOutputDescs() {
return output_descs_;
}
std::vector<aclDataBuffer *> &NpuOpRunner::GetInputBuffers() { return input_buffers_; }
std::vector<aclDataBuffer *> &NpuOpRunner::GetOutputBuffers() { return output_buffers_; }
aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
auto dtype = ConvertToNpuDtype(tensor.type());
auto format = ConvertToNpuFormat(tensor.layout());
auto dims = framework::vectorize(tensor.dims());
auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format);
PADDLE_ENFORCE_NOT_NULL(
desc, platform::errors::External("Call aclCreateTensorDesc failed."));
return desc;
}
aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
auto *buffer =
aclCreateDataBuffer(tensor.Holder()->ptr(), tensor.memory_size());
PADDLE_ENFORCE_NOT_NULL(
buffer, platform::errors::External("Call aclCreateDataBuffer failed."));
return buffer;
}
void NpuOpRunner::Run(aclrtStream stream) {
aclError ret = aclopExecuteV2(op_type_.c_str(), input_descs_.size(),
input_descs_.data(), input_buffers_.data(),
output_descs_.size(), output_descs_.data(),
output_buffers_.data(), attr_, stream);
PADDLE_ENFORCE_NPU_SUCCESS(ret);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/operator.h>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
using Attribute = framework::Attribute;
using AttributeMap = framework::AttributeMap;
class NpuOpRunner {
public:
explicit NpuOpRunner(std::string op_type);
explicit NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs = {},
const std::vector<Tensor> &outputs = {},
const AttributeMap &attrs = {});
~NpuOpRunner();
const std::string &Type();
NpuOpRunner &AddAttr(const std::string& name, const Attribute &attr);
NpuOpRunner &AddAttrs(const AttributeMap &attrs);
NpuOpRunner &AddInput(const Tensor &tensor);
NpuOpRunner &AddOutput(const Tensor &tensor);
NpuOpRunner &AddInputs(const std::vector<Tensor> &tensors);
NpuOpRunner &AddOutputs(const std::vector<Tensor> &tensors);
aclTensorDesc *GetInputDesc(size_t index);
aclTensorDesc *GetOutputDesc(size_t index);
std::vector<aclTensorDesc *> &GetInputDescs();
std::vector<aclTensorDesc *> &GetOutputDescs();
std::vector<aclDataBuffer *> &GetInputBuffers();
std::vector<aclDataBuffer *> &GetOutputBuffers();
void Run(aclrtStream stream);
private:
aclTensorDesc *CreateTensorDesc(Tensor tensor);
aclDataBuffer *CreateDataBuffer(Tensor tensor);
private:
std::string op_type_;
std::vector<aclDataBuffer *> input_buffers_;
std::vector<aclDataBuffer *> output_buffers_;
std::vector<aclTensorDesc *> input_descs_;
std::vector<aclTensorDesc *> output_descs_;
aclopAttr *attr_;
};
} // namespace operators
} // namespace paddle
......@@ -90,15 +90,28 @@ IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream)
ENDIF()
IF(WITH_ASCEND_CL)
set(NPU_CTX_DEPS npu_stream npu_info)
ENDIF()
IF(WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
IF(WITH_GPU)
nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
ENDIF()
IF(WITH_ASCEND_CL)
cc_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce atlas_acl)
ENDIF()
IF(WITH_GPU)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
ELSEIF(WITH_ASCEND_CL)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
......@@ -112,7 +125,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS})
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
......
......@@ -87,9 +87,8 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
if (it == device_contexts_.end()) {
PADDLE_THROW(platform::errors::Unimplemented(
"Place %s is not supported. Please check that your paddle compiles "
"with WITH_GPU or WITH_XPU option or check that your train process "
"hold the "
"correct gpu_id if you use Executor.",
"with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that "
"your train process set the correct device id if you use Executor.",
place));
}
return it->second.get().get();
......@@ -150,6 +149,14 @@ DeviceContextPool::DeviceContextPool(
PADDLE_THROW(
platform::errors::Unimplemented("XPUPlace is not supported. Please "
"re-compile with WITH_XPU option."));
#endif
} else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported. Please "
"re-compile with WITH_ASCEND_CL option."));
#endif
}
}
......@@ -679,6 +686,5 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
}
#endif
} // namespace platform
} // namespace paddle
......@@ -47,6 +47,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream/cuda_stream.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/stream/npu_stream.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
namespace Eigen {
......
......@@ -181,6 +181,9 @@ void InitDevices(const std::vector<int> devices) {
#endif
#ifdef PADDLE_WITH_XPU
places.emplace_back(platform::XPUPlace(devices[i]));
#endif
#ifdef PADDLE_WITH_ASCEND_CL
places.emplace_back(platform::NPUPlace(devices[i]));
#endif
}
places.emplace_back(platform::CPUPlace());
......
IF(WITH_GPU)
cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost)
cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost stream_callback_manager)
ENDIF()
IF(WITH_ASCEND_CL)
cc_library(npu_stream SRCS npu_stream.cc DEPS enforce boost stream_callback_manager)
ENDIF()
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/stream/npu_stream.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/npu_info.h"
namespace paddle {
namespace platform {
namespace stream {
bool NPUStream::Init(const Place& place) {
PADDLE_ENFORCE_EQ(is_npu_place(place), true,
platform::errors::InvalidArgument(
"NPU stream must be created using npu place."));
place_ = place;
NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device);
PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateStream(&stream_));
callback_manager_.reset(new StreamCallbackManager<aclrtStream>(stream_));
VLOG(3) << "NPUStream Init stream: " << stream_;
return true;
}
void NPUStream::Destroy() {
NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device);
Wait();
WaitCallback();
if (stream_) {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyStream(stream_));
}
stream_ = nullptr;
}
void NPUStream::Wait() const {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
}
} // namespace stream
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/npu_info.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream_callback_manager.h"
namespace paddle {
namespace platform {
namespace stream {
#ifdef PADDLE_WITH_ASCEND_CL
class NPUStream final {
public:
NPUStream() = default;
explicit NPUStream(const Place& place) { Init(place); }
virtual ~NPUStream() { Destroy(); }
bool Init(const Place& place);
template <typename Callback>
void AddCallback(Callback&& callback) const {
callback_manager_->AddCallback(callback);
}
template <typename Callback>
void RecordEvent(aclrtEvent ev, Callback callback) const {
callback();
PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_));
}
void RecordEvent(aclrtEvent ev) const {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_));
}
void WaitEvent(aclrtEvent ev) const {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(stream_, ev));
}
void Wait() const;
void WaitCallback() const { callback_manager_->Wait(); }
aclrtStream raw_stream() const { return stream_; }
void Destroy();
private:
Place place_;
aclrtStream stream_{nullptr};
std::unique_ptr<StreamCallbackManager<aclrtStream>> callback_manager_;
DISABLE_COPY_AND_ASSIGN(NPUStream);
};
#endif
} // namespace stream
} // namespace platform
} // namespace paddle
......@@ -19,22 +19,30 @@
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data)
#endif
#endif
#if PADDLE_WITH_ASCEND_CL
static void StreamCallbackFunc(void *user_data)
#endif
{
std::unique_ptr<std::function<void()>> func(
reinterpret_cast<std::function<void()> *>(user_data));
(*func)();
}
StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
template <typename Stream>
StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
: stream_(stream), thread_pool_(1) {}
void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback) const {
auto *callback_func = new std::function<void()>(std::move(callback));
auto *func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(mtx_);
......@@ -43,6 +51,7 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
(*callback_func)();
});
});
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
......@@ -50,10 +59,22 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
#endif
#endif
#if PADDLE_WITH_ASCEND_CL
PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
ACL_CALLBACK_BLOCK, stream_));
#endif
}
void StreamCallbackManager::Wait() const {
template <typename Stream>
void StreamCallbackManager<Stream>::Wait() const {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
#endif
#ifdef PADDLE_WITH_ASCEND_CL
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
#endif
{
std::lock_guard<std::mutex> lock(mtx_);
if (last_future_.valid()) {
......@@ -62,5 +83,12 @@ void StreamCallbackManager::Wait() const {
}
}
#ifdef PADDLE_WITH_CUDA
template struct StreamCallbackManager<cudaStream_t>;
#endif
#ifdef PADDLE_WITH_ASCEND_CL
template struct StreamCallbackManager<aclrtStream>;
#endif
} // namespace platform
} // namespace paddle
......@@ -15,8 +15,10 @@
#pragma once
#include <ThreadPool.h>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#include <functional>
#include <future> // NOLINT
#include <memory>
......@@ -29,9 +31,10 @@ namespace platform {
// NOTE(zjl): clean StreamCallbackManager to make compilation faster
// Make StreamCallbackManager thread-safe
template <typename Stream>
class StreamCallbackManager {
public:
explicit StreamCallbackManager(const cudaStream_t stream);
explicit StreamCallbackManager(const Stream stream);
~StreamCallbackManager() = default;
......@@ -40,7 +43,7 @@ class StreamCallbackManager {
void Wait() const;
private:
const cudaStream_t stream_;
const Stream stream_;
mutable ::ThreadPool thread_pool_;
mutable std::mutex mtx_;
mutable std::future<void> last_future_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册