未验证 提交 2305089f 编写于 作者: H Huang Jiyi 提交者: GitHub

[phi decopuling] decouple dependency to device_context in phi (Part 2) (#51541)

* platform::CUDAPinnedDeviceContext -> phi::GPUPinnedContext

* replace platform::TraceEventCollector
上级 bc3afd82
......@@ -193,23 +193,5 @@ const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
: place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
return eigen_device_.get();
}
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif
} // namespace platform
} // namespace paddle
......@@ -234,24 +234,7 @@ class NPUPinnedDeviceContext
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Currently, CUDAPinnedDeviceContext is only used to data copying.
class CUDAPinnedDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, CUDAPinnedDeviceContext> {
public:
CUDAPinnedDeviceContext();
explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
const Place& GetPlace() const override;
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "CUDAPinnedDeviceContext"; }
private:
CUDAPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
using CUDAPinnedDeviceContext = phi::GPUPinnedContext;
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......
......@@ -548,15 +548,15 @@ Meanwhile, we need to simplify the writing method of Kernel registration. The ex
```c++
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int64_t>);
scale, ops::ScaleKernel<phi::CPUContext, float>,
ops::ScaleKernel<phi::CPUContext, double>,
ops::ScaleKernel<phi::CPUContext,
phi::dtype::bfloat16>,
ops::ScaleKernel<phi::CPUContext, uint8_t>,
ops::ScaleKernel<phi::CPUContext, int8_t>,
ops::ScaleKernel<phi::CPUContext, int16_t>,
ops::ScaleKernel<phi::CPUContext, int>,
ops::ScaleKernel<phi::CPUContext, int64_t>);
```
2. Paddle-Lite's kernel registration method declares input and output information for each Kernel, but since the kernel of each data type is different, it will also cause redundancy in the writing method. As you can see in the following code, except for the data type, other information is basically redundant.
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/trace_event_collector.h"
#include "paddle/phi/api/profiler/trace_event_collector.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/custom/enforce_custom.h"
......@@ -822,45 +822,44 @@ class CustomDevice : public DeviceInterface {
}
// Profiler
void ProfilerInitialize(paddle::platform::TraceEventCollector* collector,
void ProfilerInitialize(phi::TraceEventCollector* collector,
void** user_data) override {
CHECK_PTR(pimpl_->profiler_initialize);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_initialize(
reinterpret_cast<C_Profiler>(collector), user_data));
}
void ProfilerFinalize(paddle::platform::TraceEventCollector* collector,
void ProfilerFinalize(phi::TraceEventCollector* collector,
void* user_data) override {
CHECK_PTR(pimpl_->profiler_finalize);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_finalize(
reinterpret_cast<C_Profiler>(collector), user_data));
}
void ProfilerPrepareTracing(paddle::platform::TraceEventCollector* collector,
void ProfilerPrepareTracing(phi::TraceEventCollector* collector,
void* user_data) override {
CHECK_PTR(pimpl_->profiler_prepare_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_prepare_tracing(
reinterpret_cast<C_Profiler>(collector), user_data));
}
void ProfilerStartTracing(paddle::platform::TraceEventCollector* collector,
void ProfilerStartTracing(phi::TraceEventCollector* collector,
void* user_data) override {
CHECK_PTR(pimpl_->profiler_start_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_start_tracing(
reinterpret_cast<C_Profiler>(collector), user_data));
}
void ProfilerStopTracing(paddle::platform::TraceEventCollector* collector,
void ProfilerStopTracing(phi::TraceEventCollector* collector,
void* user_data) override {
CHECK_PTR(pimpl_->profiler_stop_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_stop_tracing(
reinterpret_cast<C_Profiler>(collector), user_data));
}
void ProfilerCollectTraceData(
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns,
void* user_data) override {
void ProfilerCollectTraceData(phi::TraceEventCollector* collector,
uint64_t start_ns,
void* user_data) override {
CHECK_PTR(pimpl_->profiler_collect_trace_data);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_collect_trace_data(
reinterpret_cast<C_Profiler>(collector), start_ns, user_data));
......
......@@ -369,35 +369,33 @@ void DeviceInterface::BlasAXPBY(size_t dev_id,
}
// profiler
void DeviceInterface::ProfilerInitialize(
paddle::platform::TraceEventCollector* collector, void** user_data) {
void DeviceInterface::ProfilerInitialize(phi::TraceEventCollector* collector,
void** user_data) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::ProfilerFinalize(
paddle::platform::TraceEventCollector* collector, void* user_data) {
void DeviceInterface::ProfilerFinalize(phi::TraceEventCollector* collector,
void* user_data) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::ProfilerPrepareTracing(
paddle::platform::TraceEventCollector* collector, void* user_data) {
phi::TraceEventCollector* collector, void* user_data) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::ProfilerStartTracing(
paddle::platform::TraceEventCollector* collector, void* user_data) {
void DeviceInterface::ProfilerStartTracing(phi::TraceEventCollector* collector,
void* user_data) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::ProfilerStopTracing(
paddle::platform::TraceEventCollector* collector, void* user_data) {
void DeviceInterface::ProfilerStopTracing(phi::TraceEventCollector* collector,
void* user_data) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::ProfilerCollectTraceData(
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns,
void* user_data) {
phi::TraceEventCollector* collector, uint64_t start_ns, void* user_data) {
INTERFACE_UNIMPLEMENT;
}
......
......@@ -19,6 +19,8 @@
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/api/profiler/trace_event_collector.h"
namespace paddle {
namespace platform {
class TraceEventCollector;
......@@ -243,25 +245,24 @@ class DeviceInterface { // Driver / Runtime
void* y);
// profiler
virtual void ProfilerInitialize(
paddle::platform::TraceEventCollector* collector, void** user_data);
virtual void ProfilerInitialize(phi::TraceEventCollector* collector,
void** user_data);
virtual void ProfilerFinalize(
paddle::platform::TraceEventCollector* collector, void* user_data);
virtual void ProfilerFinalize(phi::TraceEventCollector* collector,
void* user_data);
virtual void ProfilerPrepareTracing(
paddle::platform::TraceEventCollector* collector, void* user_data);
virtual void ProfilerPrepareTracing(phi::TraceEventCollector* collector,
void* user_data);
virtual void ProfilerStartTracing(
paddle::platform::TraceEventCollector* collector, void* user_data);
virtual void ProfilerStartTracing(phi::TraceEventCollector* collector,
void* user_data);
virtual void ProfilerStopTracing(
paddle::platform::TraceEventCollector* collector, void* user_data);
virtual void ProfilerStopTracing(phi::TraceEventCollector* collector,
void* user_data);
virtual void ProfilerCollectTraceData(
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns,
void* user_data);
virtual void ProfilerCollectTraceData(phi::TraceEventCollector* collector,
uint64_t start_ns,
void* user_data);
private:
const std::string type_;
......
......@@ -597,49 +597,44 @@ void DeviceManager::CCLRecv(const std::string& device_type,
}
// profiler
void DeviceManager::ProfilerInitialize(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void** context) {
void DeviceManager::ProfilerInitialize(const std::string& dev_type,
phi::TraceEventCollector* collector,
void** context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerInitialize(collector, context);
}
void DeviceManager::ProfilerFinalize(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context) {
void DeviceManager::ProfilerFinalize(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerFinalize(collector, context);
}
void DeviceManager::ProfilerPrepareTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context) {
void DeviceManager::ProfilerPrepareTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerPrepareTracing(collector, context);
}
void DeviceManager::ProfilerStartTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context) {
void DeviceManager::ProfilerStartTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerStartTracing(collector, context);
}
void DeviceManager::ProfilerStopTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context) {
void DeviceManager::ProfilerStopTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerStopTracing(collector, context);
}
void DeviceManager::ProfilerCollectTraceData(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
phi::TraceEventCollector* collector,
uint64_t start_ns,
void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type);
......
......@@ -242,30 +242,25 @@ class DeviceManager {
const stream::Stream& stream);
// profiler
static void ProfilerInitialize(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void** context);
static void ProfilerInitialize(const std::string& dev_type,
phi::TraceEventCollector* collector,
void** context);
static void ProfilerFinalize(const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
phi::TraceEventCollector* collector,
void* context);
static void ProfilerPrepareTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context);
static void ProfilerStartTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context);
static void ProfilerStopTracing(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
void* context);
static void ProfilerCollectTraceData(
const std::string& dev_type,
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns,
void* context);
static void ProfilerPrepareTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context);
static void ProfilerStartTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context);
static void ProfilerStopTracing(const std::string& dev_type,
phi::TraceEventCollector* collector,
void* context);
static void ProfilerCollectTraceData(const std::string& dev_type,
phi::TraceEventCollector* collector,
uint64_t start_ns,
void* context);
static void Clear();
......
......@@ -1046,4 +1046,20 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
GPUPinnedContext::GPUPinnedContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
GPUPinnedContext::GPUPinnedContext(GPUPinnedPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* GPUPinnedContext::eigen_device() const {
return eigen_device_.get();
}
const Place& GPUPinnedContext::GetPlace() const { return place_; }
#endif
} // namespace phi
......@@ -278,3 +278,30 @@ using KPSContext = GPUContext;
#endif
} // namespace phi
namespace Eigen {
struct DefaultDevice;
} // namespace Eigen
namespace phi {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Currently, GPUPinnedContext is only used to data copying.
class GPUPinnedContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, GPUPinnedContext> {
public:
GPUPinnedContext();
explicit GPUPinnedContext(GPUPinnedPlace place);
const Place& GetPlace() const override;
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "GPUPinnedContext"; }
private:
GPUPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
#endif
} // namespace phi
......@@ -114,20 +114,17 @@ template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
phi::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
phi::dtype::complex<double>>;
template struct SetConstant<phi::GPUPinnedContext, float16>;
template struct SetConstant<phi::GPUPinnedContext, bfloat16>;
template struct SetConstant<phi::GPUPinnedContext, float>;
template struct SetConstant<phi::GPUPinnedContext, double>;
template struct SetConstant<phi::GPUPinnedContext, uint8_t>;
template struct SetConstant<phi::GPUPinnedContext, int>;
template struct SetConstant<phi::GPUPinnedContext, int16_t>;
template struct SetConstant<phi::GPUPinnedContext, int64_t>;
template struct SetConstant<phi::GPUPinnedContext, bool>;
template struct SetConstant<phi::GPUPinnedContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUPinnedContext, phi::dtype::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<phi::GPUContext, bool, RANK>; \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册