Unverified commit 7aaaa1c6 authored by ronnywang, committed by GitHub

Add unified device management api (#48651)

* [CustomDevice] add custom device api

* update

* update

* test=document_fix

* update

* update

* add examples
Parent 5d110365
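This commit adds pybind bindings for custom device streams and events (CustomDeviceStream, CustomDeviceEvent, _get_current_custom_device_stream, _set_current_custom_device_stream, _synchronize_custom_device) and a device-agnostic Python API under paddle.device (Stream, Event, current_stream, set_stream, stream_guard, synchronize). The following is a minimal usage sketch distilled from the test added at the end of this diff; it assumes Paddle is built with a CustomDevice plugin registered as 'custom_cpu' (the name used by the test plugin), which is an assumption of this example rather than something every build provides.

# Hedged sketch of the unified device API added in this commit; assumes a
# CustomDevice plugin registered as 'custom_cpu' is available at runtime.
import paddle

paddle.set_device('custom_cpu')

s = paddle.device.Stream()        # stream on the current (custom) device
e = paddle.device.Event()         # event on the current (custom) device

e.record(s)                       # record the event in the stream
s.wait_event(e)                   # future work on s waits for e
s.record_event(e)                 # record again via the stream, returns e
s.synchronize()                   # block until all work in s has finished

paddle.device.set_stream(s)       # make s the current stream
paddle.device.synchronize()       # block until the whole device is idle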
......@@ -76,7 +76,7 @@ bool DeviceEventQueryCustomDevice(const DeviceEvent* event) {
void DeviceEventFinishCustomDevice(const DeviceEvent* event) {
auto* wrapper =
static_cast<CustomDeviceEventWrapper*>(event->GetEvent().get());
wrapper->inner_event_->Synchonrize();
wrapper->inner_event_->Synchronize();
}
void DeviceEventCustomDeviceWaitCustomDevice(const DeviceEvent* event,
......
......@@ -138,6 +138,7 @@ set(PYBIND_SRCS
generator_py.cc
communication.cc
cuda_streams_py.cc
custom_device_py.cc
xpu_streams_py.cc
jit.cc
auto_parallel_py.cc)
......
......@@ -243,6 +243,10 @@ void BindCudaStream(py::module *m_ptr) {
print(ptr)
)DOC")
.def_property_readonly("place",
[](phi::CUDAStream &self) {
return platform::CUDAPlace(self.place());
})
#endif
.def(
"__init__",
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/custom_device_py.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindCustomDevicePy(py::module *m_ptr) {
auto &m = *m_ptr;
// Bind Methods
m.def(
"_get_current_custom_device_stream",
[](const std::string &device_type, int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto place = paddle::platform::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id);
return static_cast<const phi::CustomContext *>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->GetStream();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit _get_current_custom_device_stream."));
#endif
},
py::return_value_policy::reference,
py::arg("device_type"),
py::arg("device_id") = -1);
m.def(
"_set_current_custom_device_stream",
[](const std::string &device_type,
int device_id,
std::shared_ptr<phi::stream::Stream> stream) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto place = paddle::platform::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id);
static_cast<phi::CustomContext *>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->SetStream(stream);
return stream;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit _set_current_custom_device_stream."));
#endif
},
py::arg("device_type"),
py::arg("device_id") = -1,
py::arg("stream") = nullptr);
m.def("_synchronize_custom_device",
[](const std::string &device_type, int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto place = paddle::platform::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id);
phi::DeviceManager::SynchronizeDevice(place);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit _synchronize_custom_device."));
#endif
});
py::class_<phi::stream::Stream, std::shared_ptr<phi::stream::Stream>>(
m, "CustomDeviceStream", R"DOC(
The handle of the custom device stream.
Parameters:
device(paddle.CustomPlace()|str): The device on which to allocate the stream.
device_id(int, optional): The id of the device on which to allocate the stream.
If device is None or a negative integer, the current device will be used.
If device is a positive integer, it must be less than the device count. Default: None.
priority(int|None, optional): The priority of the stream. The priority can be 1 (high) or 2 (normal).
If priority is None, the priority is 2 (normal). Default: None.
blocking(bool, optional): Whether the stream executes synchronously. Default: False.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s3 = paddle.device.custom.Stream('custom_cpu')
s2 = paddle.device.custom.Stream('custom_cpu', 0)
s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'))
s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1)
s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1, True)
)DOC")
.def(
"__init__",
[](phi::stream::Stream &self,
const platform::CustomPlace &place,
int priority,
bool blocking) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
new (&self) phi::stream::Stream();
self.Init(
place,
static_cast<phi::stream::Stream::Priority>(priority),
static_cast<phi::stream::Stream::Flag>(
blocking ? phi::stream::Stream::Flag::kDefaultFlag
: phi::stream::Stream::Flag::kStreamNonBlocking));
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
py::arg("device"),
py::arg("priority") = 2,
py::arg("blocking") = false)
.def(
"__init__",
[](phi::stream::Stream &self,
const std::string &device_type,
int device_id,
int priority,
bool blocking) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
new (&self) phi::stream::Stream();
self.Init(
phi::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id),
static_cast<phi::stream::Stream::Priority>(priority),
static_cast<phi::stream::Stream::Flag>(
blocking ? phi::stream::Stream::Flag::kDefaultFlag
: phi::stream::Stream::Flag::kStreamNonBlocking));
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
py::arg("device"),
py::arg("device_id") = -1,
py::arg("priority") = 2,
py::arg("blocking") = false)
.def(
"wait_event",
[](const phi::stream::Stream &self, phi::event::Event *event) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
self.WaitEvent(event);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
Makes all future work submitted to the stream wait for all work captured in the given event.
Parameters:
event(CustomDeviceEvent): The event to wait on.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
s = paddle.device.custom.Stream(place)
event = paddle.device.custom.Event(place)
s.wait_event(event)
)DOC")
.def(
"wait_stream",
[](const phi::stream::Stream &self, phi::stream::Stream *other) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
phi::event::Event event;
event.Init(self.GetPlace());
event.Record(other);
self.WaitEvent(&event);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
Synchronizes with the given stream.
Parameters:
stream(CustomDeviceStream): The stream to synchronize with.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
s1 = paddle.device.custom.Stream(place)
s2 = paddle.device.custom.Stream(place)
s1.wait_stream(s2)
)DOC")
.def(
"query",
[](const phi::stream::Stream &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
return self.Query();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
Return whether all operations in the stream have completed.
Returns: A boolean value.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
s = paddle.device.custom.Stream(place)
is_done = s.query()
)DOC")
.def(
"synchronize",
[](const phi::stream::Stream &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
self.Synchronize();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
Waits for stream tasks to complete.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
s = paddle.device.custom.Stream(place)
s.synchronize()
)DOC")
.def(
"record_event",
[](const phi::stream::Stream &self, phi::event::Event *event) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (event == nullptr) {
event = new phi::event::Event;
event->Init(self.GetPlace());
}
event->Record(&self);
return event;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
Record an event in the stream.
Parameters:
event(CustomDeviceEvent, optional): The event to be recorded. If event is None, a new event is created.
Default: None.
Returns:
The recorded event.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
s = paddle.device.custom.Stream(place)
event = s.record_event()
)DOC",
py::arg("event") = nullptr)
.def_property_readonly(
"raw_stream",
[](const phi::stream::Stream &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
VLOG(10) << self.raw_stream();
return reinterpret_cast<std::uintptr_t>(self.raw_stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
},
R"DOC(
return the raw stream of type CustomDeviceStream as type int.
Examples:
.. code-block:: python
# required: custom_device
import paddle
import ctypes
stream = paddle.device.custom.current_stream().raw_stream
print(stream)
ptr = ctypes.c_void_p(stream) # convert back to void*
print(ptr)
)DOC")
.def_property_readonly("place", [](const phi::stream::Stream &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
return reinterpret_cast<const phi::CustomPlace &>(self.GetPlace());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceStream."));
#endif
});
py::class_<phi::event::Event, std::shared_ptr<phi::event::Event>>(
m, "CustomDeviceEvent", R"DOC(
The handle of the custom device event.
Parameters:
device(paddle.CustomPlace()|str): The device on which to allocate the event.
device_id(int, optional): The id of the device on which to allocate the event.
If device is None or a negative integer, the current device will be used.
If device is a positive integer, it must be less than the device count. Default: None.
enable_timing(bool, optional): Whether the event will measure time. Default: False.
blocking(bool, optional): Whether the wait() function will be blocking. Default: False.
interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
event = paddle.device.custom.Event(place)
)DOC")
.def(
"__init__",
[](phi::event::Event &self,
const platform::CustomPlace &place,
bool enable_timing,
bool blocking,
bool interprocess) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto flag = static_cast<phi::event::Event::Flag>(
static_cast<uint32_t>(
enable_timing ? 0
: phi::event::Event::Flag::DisableTiming) |
static_cast<uint32_t>(
!blocking ? 0 : phi::event::Event::Flag::BlockingSync) |
static_cast<uint32_t>(
!interprocess ? 0 : phi::event::Event::Flag::Interprocess)
);
new (&self) phi::event::Event();
self.Init(place, flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
py::arg("device"),
py::arg("enable_timing") = false,
py::arg("blocking") = false,
py::arg("interprocess") = false)
.def(
"__init__",
[](phi::event::Event &self,
const std::string &device_type,
int device_id,
bool enable_timing,
bool blocking,
bool interprocess) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto flag = static_cast<phi::event::Event::Flag>(
static_cast<uint32_t>(
enable_timing ? 0
: phi::event::Event::Flag::DisableTiming) |
static_cast<uint32_t>(
!blocking ? 0 : phi::event::Event::Flag::BlockingSync) |
static_cast<uint32_t>(
!interprocess ? 0 : phi::event::Event::Flag::Interprocess)
);
new (&self) phi::event::Event();
self.Init(
phi::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id),
flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
py::arg("device"),
py::arg("device_id") = -1,
py::arg("enable_timing") = false,
py::arg("blocking") = false,
py::arg("interprocess") = false)
.def(
"record",
[](phi::event::Event &self, phi::stream::Stream *stream) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (stream == nullptr) {
stream = static_cast<const phi::CustomContext *>(
paddle::platform::DeviceContextPool::Instance().Get(
self.GetPlace()))
->GetStream()
.get();
}
self.Record(stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
R"DOC(
Records the event in the given stream.
Parameters:
stream(CustomDeviceStream, optional): The handle of custom device stream. If None, the stream is the current stream. Default: None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
event = paddle.device.custom.Event(place)
event.record()
)DOC")
.def(
"query",
[](const phi::event::Event &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
return self.Query();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
R"DOC(
Queries the event's status.
Returns: A boolean which indicates whether all work currently captured by the event has been completed.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
event = paddle.device.custom.Event(place)
is_done = event.query()
)DOC")
.def(
"synchronize",
[](const phi::event::Event &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
self.Synchronize();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
R"DOC(
Waits for an event to complete.
Examples:
.. code-block:: python
# required: custom_device
import paddle
place = paddle.CustomPlace('custom_cpu', 0)
event = paddle.device.custom.Event(place)
event.synchronize()
)DOC")
.def_property_readonly(
"raw_event",
[](const phi::event::Event &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
VLOG(10) << self.raw_event();
return reinterpret_cast<std::uintptr_t>(self.raw_event());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
},
R"DOC(
return the raw event of type CustomDeviceEvent as type int.
Examples:
.. code-block:: python
# required: custom_device
import paddle
import ctypes
place = paddle.CustomPlace('custom_cpu', 0)
event = paddle.device.custom.Event(place)
raw_event = event.raw_event
print(raw_event)
ptr = ctypes.c_void_p(raw_event) # convert back to void*
print(ptr)
)DOC")
.def_property_readonly("place", [](const phi::event::Event &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
return reinterpret_cast<const phi::CustomPlace &>(self.GetPlace());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot visit CustomDeviceEvent."));
#endif
});
}
} // namespace pybind
} // namespace paddle
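The bindings above are registered on the libpaddle module and re-exported through paddle.fluid.core (see the core/__init__.py hunk later in this diff). A rough sketch of driving them directly, assuming a plugin named 'custom_cpu'; in normal use the paddle.device wrappers added below are the intended entry points:

# Hedged sketch of the raw bindings defined above; 'custom_cpu' is assumed
# to be a registered CustomDevice plugin.
from paddle.fluid import core

s = core.CustomDeviceStream('custom_cpu', 0, priority=2, blocking=False)
e = core.CustomDeviceEvent('custom_cpu', 0)

e.record(s)                           # record the event on the new stream
s.wait_event(e)                       # s waits for e before later work
print(s.raw_stream)                   # raw stream handle exposed as an int

core._set_current_custom_device_stream('custom_cpu', 0, s)
cur = core._get_current_custom_device_stream('custom_cpu', 0)
core._synchronize_custom_device('custom_cpu', 0)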
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindCustomDevicePy(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -88,6 +88,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/fluid/pybind/custom_device_py.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/imperative.h"
......@@ -629,6 +630,7 @@ PYBIND11_MODULE(libpaddle, m) {
BindCudaStream(&m);
BindXpuStream(&m);
BindJit(&m);
BindCustomDevicePy(&m);
// Not used, just make sure cpu_info.cc is linked.
phi::backends::cpu::CpuTotalPhysicalMemory();
......
......@@ -36,8 +36,7 @@ if(WITH_MKLDNN)
list(APPEND BACKENDS_DEPS mkldnn)
endif()
if(WITH_CUSTOM_DEVICE)
list(
list(
APPEND
BACKENDS_SRCS
callback_manager.cc
......@@ -45,9 +44,10 @@ if(WITH_CUSTOM_DEVICE)
stream.cc
event.cc
device_base.cc
device_manager.cc
custom/custom_context.cc
custom/custom_device.cc)
device_manager.cc)
if(WITH_CUSTOM_DEVICE)
list(APPEND BACKENDS_SRCS custom/custom_context.cc custom/custom_device.cc)
endif()
add_library(phi_backends "${BACKENDS_SRCS}")
......
......@@ -13,16 +13,6 @@
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <functional>
#include <future> // NOLINT
#include <memory>
......
......@@ -36,6 +36,12 @@ struct CustomContext::Impl {
return reinterpret_cast<void*>(stream_->raw_stream());
}
std::shared_ptr<phi::stream::Stream> GetStream() const { return stream_; }
void SetStream(std::shared_ptr<phi::stream::Stream> stream) {
stream_ = stream;
}
void Wait() const { stream_->Wait(); }
Place place_;
......@@ -49,6 +55,14 @@ const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); }
void* CustomContext::stream() const { return impl_->stream(); }
std::shared_ptr<phi::stream::Stream> CustomContext::GetStream() const {
return impl_->GetStream();
}
void CustomContext::SetStream(std::shared_ptr<phi::stream::Stream> stream) {
impl_->SetStream(stream);
}
void CustomContext::Wait() const { return impl_->Wait(); }
CustomContext::CustomContext(const CustomPlace& place)
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
......@@ -30,9 +31,14 @@ class CustomContext : public DeviceContext,
const Place& GetPlace() const override;
/*! \brief Return stream in the device context. */
/*! \brief Return raw stream in the device context. */
void* stream() const;
/*! \brief Return stream in the device context. */
std::shared_ptr<phi::stream::Stream> GetStream() const;
void SetStream(std::shared_ptr<phi::stream::Stream> stream);
// Wait for all operations completion in the stream.
void Wait() const override;
......
......@@ -146,13 +146,6 @@ class CustomDevice : public DeviceInterface {
stream::Stream::Priority::kNormal,
const stream::Stream::Flag& flag =
stream::Stream::Flag::kDefaultFlag) override {
if (priority != stream::Stream::Priority::kNormal ||
flag != stream::Stream::Flag::kDefaultFlag) {
PADDLE_THROW(phi::errors::Unavailable(
"priority != stream::Stream::Priority::kNormal || flag != "
"stream::Stream::Flag::kDefaultFlag is not allowed on "
"CustomDevice."));
}
const auto device = &devices_pool[dev_id];
C_Stream c_stream;
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <vector>
#include "paddle/phi/backends/c_comm_lib.h"
......@@ -275,5 +274,3 @@ class DeviceInterface { // Driver / Runtime
};
} // namespace phi
#endif
......@@ -13,8 +13,6 @@
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
namespace phi {
......@@ -46,5 +44,3 @@ class DeviceGuard {
};
} // namespace phi
#endif
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/complex.h"
......@@ -663,6 +662,8 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
std::vector<std::string> libraries;
std::regex express(".*\\.so");
std::match_results<std::string::iterator> results;
#if !defined(_WIN32)
DIR* dir = nullptr;
dirent* ptr = nullptr;
......@@ -680,9 +681,9 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
}
closedir(dir);
}
#endif
return libraries;
}
} // namespace phi
#endif
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <unordered_map>
......@@ -285,12 +284,14 @@ class DeviceManager {
std::vector<std::string> ListAllLibraries(const std::string& library_dir);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
const std::string& dso_lib_path,
void* dso_handle);
#endif
class Registrar {
public:
......@@ -303,5 +304,3 @@ class Registrar {
};
} // namespace phi
#endif
......@@ -59,7 +59,7 @@ void Event::Record(const stream::Stream* stream) { stream->RecordEvent(this); }
bool Event::Query() const { return device_->QueryEvent(this); }
void Event::Synchonrize() const { device_->SynchronizeEvent(this); }
void Event::Synchronize() const { device_->SynchronizeEvent(this); }
const Place& Event::GetPlace() const { return place_; }
......
......@@ -46,7 +46,7 @@ class Event {
void Destroy();
void Record(const stream::Stream* stream);
bool Query() const;
void Synchonrize() const;
void Synchronize() const;
const Place& GetPlace() const;
private:
......
......@@ -15,6 +15,8 @@
# TODO: define the functions to manipulate devices
import re
import os
import ctypes
import paddle
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
......@@ -43,6 +45,12 @@ __all__ = [ # noqa
'get_all_custom_device_type',
'get_available_device',
'get_available_custom_device',
'Stream',
'Event',
'current_stream',
'set_stream',
'stream_guard',
'synchronize',
]
_cudnn_version = None
......@@ -514,3 +522,475 @@ def get_available_custom_device():
# Output: ['CustomCPU', 'CustomGPU:0', 'CustomGPU:1']
"""
return core.get_available_custom_device()
class Event(object):
'''
A device event wrapper around the raw device event.
Parameters:
device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): Which device the event runs on. If device is None, the device is the current device. Default: None.
It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of the CustomDevice,
where ``x`` is the index of the GPU or custom device. It can also be paddle.CUDAPlace(n) or paddle.CustomPlace(n).
enable_timing (bool, optional): indicates if the event should measure time, default is False
blocking (bool, optional): if True, ``wait`` will be blocking, default is False
interprocess (bool): if True, the event can be shared between processes, default is False
Returns:
Event: The event.
Examples:
.. code-block:: python
# required: custom_device
import paddle
e1 = paddle.device.Event()
e2 = paddle.device.Event('custom_cpu')
e3 = paddle.device.Event('custom_cpu:0')
e4 = paddle.device.Event(paddle.CustomPlace('custom_cpu', 0))
'''
def __init__(
self,
device=None,
enable_timing=False,
blocking=False,
interprocess=False,
):
if device is None:
self.device = paddle.framework._current_expected_place()
elif isinstance(device, str):
self.device = paddle.device._convert_to_place(device)
else:
self.device = device
if paddle.is_compiled_with_cuda() and isinstance(
self.device, paddle.CUDAPlace
):
self.event_base = core.CUDAEvent(
enable_timing, blocking, interprocess
)
elif isinstance(self.device, paddle.CustomPlace):
self.event_base = core.CustomDeviceEvent(
self.device.get_device_type(),
self.device.get_device_id(),
enable_timing,
blocking,
interprocess,
)
else:
raise TypeError(
"device should be gpu, xpu, {}".format(
",".join(paddle.device.get_all_custom_device_type())
)
)
def record(self, stream=None):
'''
Records the event in a given stream.
Parameters:
stream(Stream, optional): The given stream. By default, stream is None and
the event will be recorded in the current stream.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
e = paddle.device.Event()
e.record()
s = paddle.device.Stream()
e.record(s)
'''
if stream is None:
stream = current_stream(self.device)
self.event_base.record(stream.stream_base)
def query(self):
'''
Checks if all work currently captured by event has completed.
Returns:
bool: Whether all work currently captured by event has completed.
Examples:
.. code-block:: python
# required: custom_device
import paddle
e = paddle.device.Event()
e.query()
'''
return self.event_base.query()
def elapsed_time(self, end_event):
'''
Returns the time elapsed in milliseconds after the event was
recorded and before the end_event was recorded.
Returns:
int: The time.
Examples:
.. code-block:: python
# required: custom_device
import paddle
e1 = paddle.device.Event()
e2 = paddle.device.Event()
e1.elapsed_time(e2)
'''
return 0
def synchronize(self):
'''
Waits for the event to complete.
Waits until the completion of all work currently captured in this event.
This prevents the CPU thread from proceeding until the event completes.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
e = paddle.device.Event()
e.synchronize()
'''
self.event_base.synchronize()
def __repr__(self):
return self.event_base
class Stream(object):
'''
A device stream wrapper around StreamBase.
Parameters:
device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): Which device the stream runs on. If device is None, the device is the current device. Default: None.
It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of the CustomDevice,
where ``x`` is the index of the GPU or custom device. It can also be paddle.CUDAPlace(n) or paddle.CustomPlace(n).
priority(int, optional): priority of the stream. Can be either
1 (high priority) or 2 (normal priority). By default, streams have
priority 2.
Returns:
Stream: The stream.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s1 = paddle.device.Stream()
s2 = paddle.device.Stream('custom_cpu')
s3 = paddle.device.Stream('custom_cpu:0')
s4 = paddle.device.Stream(paddle.CustomPlace('custom_cpu', 0))
'''
def __init__(self, device=None, priority=2, stream_base=None):
if stream_base is not None:
if isinstance(
stream_base, (core.CUDAStream, core.CustomDeviceStream)
):
self.stream_base = stream_base
self.device = stream_base.place
else:
raise TypeError(
"stream_base should be CUDAStream, CustomDeviceStream"
)
return
if device is None:
self.device = paddle.framework._current_expected_place()
elif isinstance(device, str):
self.device = paddle.device._convert_to_place(device)
else:
self.device = device
if paddle.is_compiled_with_cuda() and isinstance(
self.device, paddle.CUDAPlace
):
self.stream_base = core.CUDAStream(
self.device.get_device_id(), priority
)
elif isinstance(self.device, paddle.CustomPlace):
self.stream_base = core.CustomDeviceStream(
self.device.get_device_type(),
self.device.get_device_id(),
priority,
blocking=False,
)
else:
raise TypeError(
"device should be gpu, xpu, {}".format(
",".join(paddle.device.get_all_custom_device_type())
)
)
def wait_event(self, event):
'''
Makes all future work submitted to the stream wait for an event.
Parameters:
event (Event): an event to wait for.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
e = paddle.device.Event()
s.wait_event(e)
'''
self.stream_base.wait_event(event.event_base)
def wait_stream(self, stream):
'''
Synchronizes with another stream.
All future work submitted to this stream will wait until all kernels
submitted to a given stream at the time of call complete.
Parameters:
stream (Stream): a stream to synchronize.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s1 = paddle.device.Stream()
s2 = paddle.device.Stream()
s1.wait_stream(s2)
'''
self.stream_base.wait_stream(stream.stream_base)
def record_event(self, event=None):
'''
Records an event.
Parameters:
event (Event, optional): event to record. If not given, a new one
will be allocated.
Returns:
Event: Recorded event.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
e1 = s.record_event()
e2 = paddle.device.Event()
s.record_event(e2)
'''
if event is None:
event = Event(self.device)
event.record(self)
return event
def query(self):
'''
Checks if all the work submitted has been completed.
Returns:
bool: Whether all kernels in this stream are completed.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
s.query()
'''
return self.stream_base.query()
def synchronize(self):
'''
Wait for all the kernels in this stream to complete.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
s.synchronize()
'''
self.stream_base.synchronize()
@property
def _as_parameter_(self):
if isinstance(self.stream_base, core.CUDAStream):
return ctypes.c_void_p(self.stream_base.cuda_stream)
else:
return ctypes.c_void_p(self.stream_base.raw_stream)
def __eq__(self, o):
if isinstance(o, Stream):
return super(Stream, self).__eq__(o)
return False
def __hash__(self):
return hash((self.stream_base, self.device))
def __repr__(self):
return '<paddle.device.Stream device={0} stream={1:#x}>'.format(
self.device, self._as_parameter_.value
)
def current_stream(device=None):
'''
Return the current stream of the given device.
Parameters:
device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)): The device to get the stream from. If device is None, the device is the current device. Default: None.
It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of the CustomDevice,
where ``x`` is the index of the GPU or custom device. It can also be paddle.CUDAPlace(n) or paddle.CustomPlace(n).
Returns:
Stream: The stream to the device.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s1 = paddle.device.current_stream()
s2 = paddle.device.current_stream("gpu:0")
place = paddle.CustomPlace('custom_cpu', 0)
s3 = paddle.device.current_stream(place)
'''
if device is None:
place = paddle.framework._current_expected_place()
elif isinstance(device, str):
place = paddle.device._convert_to_place(device)
else:
place = device
if paddle.is_compiled_with_cuda() and isinstance(place, paddle.CUDAPlace):
return Stream(
stream_base=core._get_current_stream(place.get_device_id())
)
elif isinstance(place, paddle.CustomPlace):
return Stream(
stream_base=core._get_current_custom_device_stream(
place.get_device_type(), place.get_device_id()
)
)
else:
raise TypeError(
"device should be gpu, xpu, {}".format(
",".join(paddle.device.get_all_custom_device_type())
)
)
def set_stream(stream):
'''
Set the current stream.
Parameters:
stream(Stream): The selected stream.
Returns:
Stream: The previous stream.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
paddle.device.set_stream(s)
'''
prev_stream = current_stream(stream.stream_base.place)
if paddle.is_compiled_with_cuda() and isinstance(
stream.stream_base.place, paddle.CUDAPlace
):
core._set_current_stream(stream.stream_base)
elif isinstance(stream.stream_base.place, paddle.CustomPlace):
core._set_current_custom_device_stream(
stream.stream_base.place.get_device_type(),
stream.stream_base.place.get_device_id(),
stream.stream_base,
)
else:
raise TypeError(
"device should be gpu, xpu, {}".format(
",".join(paddle.device.get_all_custom_device_type())
)
)
return prev_stream
class stream_guard(object):
'''
Notes:
This API only supports dynamic graph mode currently.
A context manager that makes the given stream the current stream for the duration of the context.
Parameters:
stream(Stream, optional): the selected stream. If stream is None, just yield.
Returns:
None.
Examples:
.. code-block:: python
# required: custom_device
import paddle
s = paddle.device.Stream()
data1 = paddle.ones(shape=[20])
data2 = paddle.ones(shape=[20])
data3 = data1 + data2
with paddle.device.stream_guard(s):
s.wait_stream(paddle.device.default_stream())
data4 = data1 + data3
'''
def __init__(self, stream=None):
self.stream = stream
def __enter__(self):
cur_stream = self.stream
if cur_stream is None:
return
self.src_prev_stream = current_stream(cur_stream.device)
if self.src_prev_stream.device != cur_stream.device:
self.tmp_place = paddle.fluid.framework._current_expected_place()
paddle.fluid.framework._set_expected_place(cur_stream.device)
self.dst_prev_stream = current_stream(cur_stream.device)
set_stream(cur_stream)
else:
set_stream(cur_stream)
def __exit__(self, *args):
cur_stream = self.stream
if cur_stream is None:
return
if self.src_prev_stream.device != cur_stream.device:
set_stream(self.dst_prev_stream)
paddle.fluid.framework._set_expected_place(self.tmp_place)
set_stream(self.src_prev_stream)
else:
set_stream(self.src_prev_stream)
def synchronize(device=None):
'''
Wait for the compute on the given device to finish.
Parameters:
device(str|paddle.CUDAPlace(n)|paddle.XPUPlace(n)|paddle.CustomPlace(n)): The device to wait for. If device is None, the device is the current device. Default: None.
It can be ``gpu``, ``gpu:x``, ``xpu``, ``xpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of the CustomDevice,
where ``x`` is the index of the GPU, XPU or custom device. It can also be paddle.CUDAPlace(n) or paddle.XPUPlace(n) or paddle.CustomPlace(n).
Examples:
.. code-block:: python
# required: custom_device
import paddle
paddle.device.synchronize()
paddle.device.synchronize("gpu:0")
place = paddle.CustomPlace('custom_cpu', 0)
paddle.device.synchronize(place)
'''
if device is None:
place = paddle.framework._current_expected_place()
elif isinstance(device, str):
place = paddle.device._convert_to_place(device)
else:
place = device
if paddle.is_compiled_with_cuda() and isinstance(place, paddle.CUDAPlace):
core._device_synchronize(place.get_device_id())
elif paddle.is_compiled_with_xpu() and isinstance(place, paddle.XPUPlace):
core._xpu_device_synchronize(place.get_device_id())
elif isinstance(place, paddle.CustomPlace):
core._synchronize_custom_device(
place.get_device_type(), place.get_device_id()
)
else:
raise TypeError(
"device should be gpu, xpu, {}".format(
",".join(paddle.device.get_all_custom_device_type())
)
)
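The stream_guard context manager above temporarily installs a stream via set_stream and restores the previous one on exit, also switching the expected place when the guarded stream lives on a different device. A hedged sketch of the overlap pattern it enables, modeled on the stream_guard docstring and again assuming a 'custom_cpu' plugin:

# Hedged sketch: run work on a side stream and join it back; assumes a
# CustomDevice plugin named 'custom_cpu'.
import paddle

paddle.set_device('custom_cpu')

main = paddle.device.current_stream()
side = paddle.device.Stream()

data1 = paddle.ones(shape=[20])
data2 = paddle.ones(shape=[20])

with paddle.device.stream_guard(side):
    side.wait_stream(main)        # do not start before main's queued work
    data3 = data1 + data2         # launched while side is the current stream

main.wait_stream(side)            # main will not run ahead of side's result
paddle.device.synchronize()       # or: side.synchronize()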
......@@ -15,6 +15,7 @@
import paddle
from paddle.fluid import core
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.utils import deprecated
from .streams import Stream # noqa: F401
from .streams import Event # noqa: F401
......@@ -37,6 +38,12 @@ __all__ = [
]
@deprecated(
since="2.5.0",
update_to="paddle.device.current_stream",
level=1,
reason="current_stream in paddle.device.cuda will be removed in future",
)
def current_stream(device=None):
'''
Return the current CUDA stream by the device.
......@@ -75,6 +82,12 @@ def current_stream(device=None):
return core._get_current_stream(device_id)
@deprecated(
since="2.5.0",
update_to="paddle.device.synchronize",
level=1,
reason="synchronize in paddle.device.cuda will be removed in future",
)
def synchronize(device=None):
'''
Wait for the compute on the given CUDA device to finish.
......@@ -352,6 +365,12 @@ def _set_current_stream(stream):
return core._set_current_stream(stream)
@deprecated(
since="2.5.0",
update_to="paddle.device.stream_guard",
level=1,
reason="stream_guard in paddle.device.cuda will be removed in future",
)
@signature_safe_contextmanager
def stream_guard(stream):
'''
......
......@@ -14,12 +14,19 @@
import paddle
from paddle.fluid import core
from paddle.utils import deprecated
__all__ = [
'synchronize',
]
@deprecated(
since="2.5.0",
update_to="paddle.device.synchronize",
level=1,
reason="synchronize in paddle.device.xpu will be removed in future",
)
def synchronize(device=None):
'''
Wait for the compute on the given XPU device to finish.
......
......@@ -314,6 +314,13 @@ try:
from .libpaddle import _is_fwd_prim_enabled
from .libpaddle import __set_all_prim_enabled
# custom device
from .libpaddle import _get_current_custom_device_stream
from .libpaddle import _set_current_custom_device_stream
from .libpaddle import _synchronize_custom_device
from .libpaddle import CustomDeviceStream
from .libpaddle import CustomDeviceEvent
if sys.platform != 'win32':
from .libpaddle import _set_process_pids
from .libpaddle import _erase_process_pids
......
......@@ -56,6 +56,7 @@ class TestCustomCPUPlugin(unittest.TestCase):
self._test_eager_copy_to()
self._test_fallback_kernel()
self._test_scalar()
self._test_custom_device_py_api()
def _test_custom_device_dataloader(self):
import paddle
......@@ -257,6 +258,34 @@ class TestCustomCPUPlugin(unittest.TestCase):
avg_loss.backward()
sgd.step()
def _test_custom_device_py_api(self):
import paddle
p = paddle.set_device('custom_cpu')
paddle.device.synchronize('custom_cpu')
s1 = paddle.device.Stream()
s2 = paddle.device.Stream(p)
s1 = paddle.device.current_stream()
s2 = paddle.device.current_stream(p)
e1 = paddle.device.Event()
e2 = paddle.device.Event(p)
s = paddle.device.Stream()
e = paddle.device.Event()
s.query()
s.synchronize()
s.wait_event(e)
s.record_event(e)
s.wait_stream(s)
paddle.device.set_stream(s)
e.query()
e.synchronize()
e.record(s)
if __name__ == '__main__':
if os.name == 'nt' or sys.platform.startswith('darwin'):
......