Unverified commit 762819a8 authored by W WangXi, committed by GitHub

[npu][hybrid] support offload (#37224)

Parent 5237cc05
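For context: this patch generalizes the existing GPU offload path so it also works on NPU. Offload is typically switched on through fleet's sharding configuration; the sketch below is illustrative only (the `offload` and `sharding_degree` keys come from `DistributedStrategy.sharding_configs`, not from this patch).

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# "offload" moves fp32 master weights and optimizer states into pinned
# host memory; with this patch the same pass also works on NPU builds.
strategy.sharding_configs = {
    "sharding_degree": 8,
    "offload": True,
}
```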
paddle/fluid/framework/tensor_util.cc
@@ -151,6 +151,52 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size,
stream);
}
else if (platform::is_npu_pinned_place(src_place) && // NOLINT
platform::is_npu_place(dst_place)) { /* npu_pinned->npu */
auto src_npu_pinned_place =
BOOST_GET_CONST(platform::NPUPinnedPlace, src_place);
auto dst_npu_place = BOOST_GET_CONST(platform::NPUPlace, dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(platform::is_npu_place(ctx_place), true,
platform::errors::PreconditionNotMet(
"Device context place mismatch. When copying Tensor "
"data from NPU Pinned memory to NPU memory, current "
"device context place should be NPU."));
auto ctx_npu_place = BOOST_GET_CONST(platform::NPUPlace, ctx_place);
PADDLE_ENFORCE_EQ(dst_npu_place, ctx_npu_place,
platform::errors::PreconditionNotMet(
"The target NPU device and current device context do "
"not match. The target NPU device number is %d, but "
"device context NPU number is %d.",
dst_npu_place.device, ctx_npu_place.device));
auto stream =
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
memory::Copy(dst_npu_place, dst_ptr, src_npu_pinned_place, src_ptr, size,
stream);
}
else if (platform::is_npu_place(src_place) && // NOLINT
platform::is_npu_pinned_place(dst_place)) { /* npu->npu_pinned */
auto src_npu_place = BOOST_GET_CONST(platform::NPUPlace, src_place);
auto dst_npu_pinned_place =
BOOST_GET_CONST(platform::NPUPinnedPlace, dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(platform::is_npu_place(ctx_place), true,
platform::errors::PreconditionNotMet(
"Device context place mismatch. When copying Tensor "
"data from NPU memory to NPU Pinned memory, current "
"device context place should be NPU."));
auto ctx_npu_place = BOOST_GET_CONST(platform::NPUPlace, ctx_place);
PADDLE_ENFORCE_EQ(src_npu_place, ctx_npu_place,
platform::errors::PreconditionNotMet(
"The source NPU device and current device context do "
"not match. The source NPU device number is %d, but "
"device context NPU number is %d.",
src_npu_place.device, ctx_npu_place.device));
auto stream =
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
memory::Copy(dst_npu_pinned_place, dst_ptr, src_npu_place, src_ptr, size,
stream);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
......
paddle/fluid/memory/allocation/npu_pinned_allocator.cc
@@ -54,6 +54,17 @@ void NPUPinnedAllocator::FreeImpl(Allocation *allocation) {
std::lock_guard<std::mutex> lock(mtx_);
void *ptr = allocation->ptr();
auto iter = npu_events_.find(allocation);
// Managed by the GC if RecordEvent was never called.
if (iter == npu_events_.end()) {
// Double free? No such problem has been found so far.
// Alternatively, we may need a set<Allocation*> to record which
// Allocations are managed by the GC.
free(ptr);
delete allocation;
return;
}
aclrtEvent event = iter->second;
aclrtEventStatus status = ACL_EVENT_STATUS_COMPLETE;
PADDLE_ENFORCE_NPU_SUCCESS(aclrtQueryEvent(event, &status));
......
paddle/fluid/operators/memcpy_op.h
@@ -36,6 +36,16 @@ class SelectedRows;
namespace paddle {
namespace operators {
class MemcpyFunctor {
private:
enum DeviceType {
CPU = 0,
CUDA = 1,
CUDA_PINNED = 2,
XPU = 3,
NPU = 4,
NPU_PINNED = 5,
};
public:
MemcpyFunctor(framework::Variable *out,
const platform::DeviceContext &dev_ctx,
@@ -45,18 +55,21 @@ class MemcpyFunctor {
void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
- if (dst_place_type_ == 2) {
+ if (dst_place_type_ == DeviceType::CUDA_PINNED) {
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
&out_tensor);
- } else if (dst_place_type_ == 1) {
+ } else if (dst_place_type_ == DeviceType::CUDA) {
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
- } else if (dst_place_type_ == 0) {
+ } else if (dst_place_type_ == DeviceType::CPU) {
framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor);
#ifdef PADDLE_WITH_ASCEND_CL
- } else if (dst_place_type_ == 4) {
+ } else if (dst_place_type_ == DeviceType::NPU) { /* npu_pin->npu */
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
} else if (dst_place_type_ == DeviceType::NPU_PINNED) { /* npu->npu_pin */
framework::TensorCopy(lod_tensor, platform::NPUPinnedPlace(), dev_ctx_,
&out_tensor);
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
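On the Python side, an offload pass emits the `memcpy` op with an integer `dst_place_type` that must match the enum above. A minimal, self-contained sketch (the variable names are made up for illustration; 5 corresponds to `NPU_PINNED`):

```python
import paddle.fluid as fluid

main = fluid.Program()
block = main.global_block()
src = block.create_var(name='fp32_param', shape=[8], dtype='float32')
dst = block.create_var(name='fp32_param@Pinned', shape=[8], dtype='float32')
# Emit a device-to-pinned-host copy; the C++ MemcpyFunctor dispatches on
# this attribute value.
block.append_op(
    type='memcpy',
    inputs={'X': src},
    outputs={'Out': dst},
    attrs={'dst_place_type': 5})
```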
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -20,10 +20,36 @@ from .shard import Shard
__all__ = []
class PlaceType:
# Must stay in sync with the memcpy op's dst_place_type values; arguably not a good design.
CPU = 0
CUDA = 1
CUDA_PINNED = 2
XPU = 3 # unsupported for now
NPU = 4
NPU_PINNED = 5
@staticmethod
def default_device():
if core.is_compiled_with_cuda():
return PlaceType.CUDA
elif core.is_compiled_with_npu():
return PlaceType.NPU
return PlaceType.CPU
@staticmethod
def default_pinned():
if core.is_compiled_with_cuda():
return PlaceType.CUDA_PINNED
elif core.is_compiled_with_npu():
return PlaceType.NPU_PINNED
return PlaceType.CPU
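A quick illustration of how the defaults resolve per build:

```python
# CUDA build: default_device() -> 1 (CUDA), default_pinned() -> 2 (CUDA_PINNED)
# NPU build:  default_device() -> 4 (NPU),  default_pinned() -> 5 (NPU_PINNED)
# CPU-only:   both fall back to 0 (CPU), since there is no pinned memory to use
print(PlaceType.default_device(), PlaceType.default_pinned())
```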
class OffloadHelper(object):
cpu_place_type = 0
- cuda_place_type = 1
- cuda_pinned_place_type = 2
+ cuda_place_type = PlaceType.default_device()
+ cuda_pinned_place_type = PlaceType.default_pinned()
def __init__(self, mp_ring_id=None, dp_ring_id=None):
self.mp_ring_id = mp_ring_id
......
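With this change the helper's (still CUDA-named) class attributes resolve to NPU place types on an NPU build, so the existing offload pass emits NPU-aware memcpy ops unchanged. For illustration, assuming `core.is_compiled_with_npu()` is true:

```python
assert OffloadHelper.cuda_place_type == PlaceType.NPU                # == 4
assert OffloadHelper.cuda_pinned_place_type == PlaceType.NPU_PINNED  # == 5
```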