未验证 提交 2305089f 编写于 作者: H Huang Jiyi 提交者: GitHub

[phi decopuling] decouple dependency to device_context in phi (Part 2) (#51541)

* platform::CUDAPinnedDeviceContext -> phi::GPUPinnedContext

* replace platform::TraceEventCollector
上级 bc3afd82
...@@ -193,23 +193,5 @@ const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; } ...@@ -193,23 +193,5 @@ const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
: place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
return eigen_device_.get();
}
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -234,24 +234,7 @@ class NPUPinnedDeviceContext ...@@ -234,24 +234,7 @@ class NPUPinnedDeviceContext
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Currently, CUDAPinnedDeviceContext is only used to data copying. using CUDAPinnedDeviceContext = phi::GPUPinnedContext;
class CUDAPinnedDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, CUDAPinnedDeviceContext> {
public:
CUDAPinnedDeviceContext();
explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
const Place& GetPlace() const override;
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "CUDAPinnedDeviceContext"; }
private:
CUDAPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
......
...@@ -548,15 +548,15 @@ Meanwhile, we need to simplify the writing method of Kernel registration. The ex ...@@ -548,15 +548,15 @@ Meanwhile, we need to simplify the writing method of Kernel registration. The ex
```c++ ```c++
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>, scale, ops::ScaleKernel<phi::CPUContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>, ops::ScaleKernel<phi::CPUContext, double>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, ops::ScaleKernel<phi::CPUContext,
paddle::platform::bfloat16>, phi::dtype::bfloat16>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::ScaleKernel<phi::CPUContext, uint8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::ScaleKernel<phi::CPUContext, int8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>, ops::ScaleKernel<phi::CPUContext, int16_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int>, ops::ScaleKernel<phi::CPUContext, int>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::ScaleKernel<phi::CPUContext, int64_t>);
``` ```
2. Paddle-Lite's kernel registration method declares input and output information for each Kernel, but since the kernel of each data type is different, it will also cause redundancy in the writing method. As you can see in the following code, except for the data type, other information is basically redundant. 2. Paddle-Lite's kernel registration method declares input and output information for each Kernel, but since the kernel of each data type is different, it will also cause redundancy in the writing method. As you can see in the following code, except for the data type, other information is basically redundant.
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/phi/api/profiler/trace_event_collector.h"
#include "paddle/phi/backends/callback_manager.h" #include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/custom/enforce_custom.h" #include "paddle/phi/backends/custom/enforce_custom.h"
...@@ -822,43 +822,42 @@ class CustomDevice : public DeviceInterface { ...@@ -822,43 +822,42 @@ class CustomDevice : public DeviceInterface {
} }
// Profiler // Profiler
void ProfilerInitialize(paddle::platform::TraceEventCollector* collector, void ProfilerInitialize(phi::TraceEventCollector* collector,
void** user_data) override { void** user_data) override {
CHECK_PTR(pimpl_->profiler_initialize); CHECK_PTR(pimpl_->profiler_initialize);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_initialize( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_initialize(
reinterpret_cast<C_Profiler>(collector), user_data)); reinterpret_cast<C_Profiler>(collector), user_data));
} }
void ProfilerFinalize(paddle::platform::TraceEventCollector* collector, void ProfilerFinalize(phi::TraceEventCollector* collector,
void* user_data) override { void* user_data) override {
CHECK_PTR(pimpl_->profiler_finalize); CHECK_PTR(pimpl_->profiler_finalize);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_finalize( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_finalize(
reinterpret_cast<C_Profiler>(collector), user_data)); reinterpret_cast<C_Profiler>(collector), user_data));
} }
void ProfilerPrepareTracing(paddle::platform::TraceEventCollector* collector, void ProfilerPrepareTracing(phi::TraceEventCollector* collector,
void* user_data) override { void* user_data) override {
CHECK_PTR(pimpl_->profiler_prepare_tracing); CHECK_PTR(pimpl_->profiler_prepare_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_prepare_tracing( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_prepare_tracing(
reinterpret_cast<C_Profiler>(collector), user_data)); reinterpret_cast<C_Profiler>(collector), user_data));
} }
void ProfilerStartTracing(paddle::platform::TraceEventCollector* collector, void ProfilerStartTracing(phi::TraceEventCollector* collector,
void* user_data) override { void* user_data) override {
CHECK_PTR(pimpl_->profiler_start_tracing); CHECK_PTR(pimpl_->profiler_start_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_start_tracing( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_start_tracing(
reinterpret_cast<C_Profiler>(collector), user_data)); reinterpret_cast<C_Profiler>(collector), user_data));
} }
void ProfilerStopTracing(paddle::platform::TraceEventCollector* collector, void ProfilerStopTracing(phi::TraceEventCollector* collector,
void* user_data) override { void* user_data) override {
CHECK_PTR(pimpl_->profiler_stop_tracing); CHECK_PTR(pimpl_->profiler_stop_tracing);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_stop_tracing( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->profiler_stop_tracing(
reinterpret_cast<C_Profiler>(collector), user_data)); reinterpret_cast<C_Profiler>(collector), user_data));
} }
void ProfilerCollectTraceData( void ProfilerCollectTraceData(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns, uint64_t start_ns,
void* user_data) override { void* user_data) override {
CHECK_PTR(pimpl_->profiler_collect_trace_data); CHECK_PTR(pimpl_->profiler_collect_trace_data);
......
...@@ -369,35 +369,33 @@ void DeviceInterface::BlasAXPBY(size_t dev_id, ...@@ -369,35 +369,33 @@ void DeviceInterface::BlasAXPBY(size_t dev_id,
} }
// profiler // profiler
void DeviceInterface::ProfilerInitialize( void DeviceInterface::ProfilerInitialize(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void** user_data) { void** user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::ProfilerFinalize( void DeviceInterface::ProfilerFinalize(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data) { void* user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::ProfilerPrepareTracing( void DeviceInterface::ProfilerPrepareTracing(
paddle::platform::TraceEventCollector* collector, void* user_data) { phi::TraceEventCollector* collector, void* user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::ProfilerStartTracing( void DeviceInterface::ProfilerStartTracing(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data) { void* user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::ProfilerStopTracing( void DeviceInterface::ProfilerStopTracing(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data) { void* user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
void DeviceInterface::ProfilerCollectTraceData( void DeviceInterface::ProfilerCollectTraceData(
paddle::platform::TraceEventCollector* collector, phi::TraceEventCollector* collector, uint64_t start_ns, void* user_data) {
uint64_t start_ns,
void* user_data) {
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "paddle/phi/backends/event.h" #include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h" #include "paddle/phi/backends/stream.h"
#include "paddle/phi/api/profiler/trace_event_collector.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
class TraceEventCollector; class TraceEventCollector;
...@@ -243,23 +245,22 @@ class DeviceInterface { // Driver / Runtime ...@@ -243,23 +245,22 @@ class DeviceInterface { // Driver / Runtime
void* y); void* y);
// profiler // profiler
virtual void ProfilerInitialize( virtual void ProfilerInitialize(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void** user_data); void** user_data);
virtual void ProfilerFinalize( virtual void ProfilerFinalize(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data); void* user_data);
virtual void ProfilerPrepareTracing( virtual void ProfilerPrepareTracing(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data); void* user_data);
virtual void ProfilerStartTracing( virtual void ProfilerStartTracing(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data); void* user_data);
virtual void ProfilerStopTracing( virtual void ProfilerStopTracing(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector, void* user_data); void* user_data);
virtual void ProfilerCollectTraceData( virtual void ProfilerCollectTraceData(phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns, uint64_t start_ns,
void* user_data); void* user_data);
......
...@@ -597,41 +597,36 @@ void DeviceManager::CCLRecv(const std::string& device_type, ...@@ -597,41 +597,36 @@ void DeviceManager::CCLRecv(const std::string& device_type,
} }
// profiler // profiler
void DeviceManager::ProfilerInitialize( void DeviceManager::ProfilerInitialize(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void** context) { void** context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerInitialize(collector, context); dev_impl->ProfilerInitialize(collector, context);
} }
void DeviceManager::ProfilerFinalize( void DeviceManager::ProfilerFinalize(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context) { void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerFinalize(collector, context); dev_impl->ProfilerFinalize(collector, context);
} }
void DeviceManager::ProfilerPrepareTracing( void DeviceManager::ProfilerPrepareTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context) { void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerPrepareTracing(collector, context); dev_impl->ProfilerPrepareTracing(collector, context);
} }
void DeviceManager::ProfilerStartTracing( void DeviceManager::ProfilerStartTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context) { void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerStartTracing(collector, context); dev_impl->ProfilerStartTracing(collector, context);
} }
void DeviceManager::ProfilerStopTracing( void DeviceManager::ProfilerStopTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context) { void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
dev_impl->ProfilerStopTracing(collector, context); dev_impl->ProfilerStopTracing(collector, context);
...@@ -639,7 +634,7 @@ void DeviceManager::ProfilerStopTracing( ...@@ -639,7 +634,7 @@ void DeviceManager::ProfilerStopTracing(
void DeviceManager::ProfilerCollectTraceData( void DeviceManager::ProfilerCollectTraceData(
const std::string& dev_type, const std::string& dev_type,
paddle::platform::TraceEventCollector* collector, phi::TraceEventCollector* collector,
uint64_t start_ns, uint64_t start_ns,
void* context) { void* context) {
auto dev_impl = GetDeviceInterfaceWithType(dev_type); auto dev_impl = GetDeviceInterfaceWithType(dev_type);
......
...@@ -242,28 +242,23 @@ class DeviceManager { ...@@ -242,28 +242,23 @@ class DeviceManager {
const stream::Stream& stream); const stream::Stream& stream);
// profiler // profiler
static void ProfilerInitialize( static void ProfilerInitialize(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void** context); void** context);
static void ProfilerFinalize(const std::string& dev_type, static void ProfilerFinalize(const std::string& dev_type,
paddle::platform::TraceEventCollector* collector, phi::TraceEventCollector* collector,
void* context); void* context);
static void ProfilerPrepareTracing( static void ProfilerPrepareTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context); void* context);
static void ProfilerStartTracing( static void ProfilerStartTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context); void* context);
static void ProfilerStopTracing( static void ProfilerStopTracing(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
void* context); void* context);
static void ProfilerCollectTraceData( static void ProfilerCollectTraceData(const std::string& dev_type,
const std::string& dev_type, phi::TraceEventCollector* collector,
paddle::platform::TraceEventCollector* collector,
uint64_t start_ns, uint64_t start_ns,
void* context); void* context);
......
...@@ -1046,4 +1046,20 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) { ...@@ -1046,4 +1046,20 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); } void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
GPUPinnedContext::GPUPinnedContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
GPUPinnedContext::GPUPinnedContext(GPUPinnedPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* GPUPinnedContext::eigen_device() const {
return eigen_device_.get();
}
const Place& GPUPinnedContext::GetPlace() const { return place_; }
#endif
} // namespace phi } // namespace phi
...@@ -278,3 +278,30 @@ using KPSContext = GPUContext; ...@@ -278,3 +278,30 @@ using KPSContext = GPUContext;
#endif #endif
} // namespace phi } // namespace phi
namespace Eigen {
struct DefaultDevice;
} // namespace Eigen
namespace phi {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Currently, GPUPinnedContext is only used to data copying.
class GPUPinnedContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, GPUPinnedContext> {
public:
GPUPinnedContext();
explicit GPUPinnedContext(GPUPinnedPlace place);
const Place& GetPlace() const override;
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "GPUPinnedContext"; }
private:
GPUPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
#endif
} // namespace phi
...@@ -114,20 +114,17 @@ template struct SetConstant<phi::GPUContext, bool>; ...@@ -114,20 +114,17 @@ template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>; template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>; template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>; template struct SetConstant<phi::GPUPinnedContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<phi::GPUPinnedContext, bfloat16>;
bfloat16>; template struct SetConstant<phi::GPUPinnedContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>; template struct SetConstant<phi::GPUPinnedContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>; template struct SetConstant<phi::GPUPinnedContext, uint8_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>; template struct SetConstant<phi::GPUPinnedContext, int>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int>; template struct SetConstant<phi::GPUPinnedContext, int16_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>; template struct SetConstant<phi::GPUPinnedContext, int64_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>; template struct SetConstant<phi::GPUPinnedContext, bool>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>; template struct SetConstant<phi::GPUPinnedContext, phi::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<phi::GPUPinnedContext, phi::dtype::complex<double>>;
phi::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
phi::dtype::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<phi::GPUContext, bool, RANK>; \ template struct Transpose<phi::GPUContext, bool, RANK>; \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册