未验证 提交 047ee26c 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Unify device context entrance in pten part 1 (#38172)

* unify device context entrance

* move all_context include to header

* polish cmake relay for device_context

* fix npu compile failed

* fix npu compile failed

* revert part of change
上级 4ef59f08
...@@ -23,7 +23,7 @@ add_subdirectory(ops) ...@@ -23,7 +23,7 @@ add_subdirectory(ops)
add_subdirectory(tests) add_subdirectory(tests)
# make an unity target for compile deps # make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor kernel_factory kernel_context) set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context)
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu) set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu)
set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
......
...@@ -10,7 +10,7 @@ else() ...@@ -10,7 +10,7 @@ else()
cc_library(pten_tensor SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils enforce) cc_library(pten_tensor SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils enforce)
endif() endif()
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor device_context kernel_factory) cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor pten_context kernel_factory)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor) cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
......
...@@ -51,8 +51,7 @@ std::size_t CountLeadingZeros(uint64_t val) { ...@@ -51,8 +51,7 @@ std::size_t CountLeadingZeros(uint64_t val) {
} // namespace detail } // namespace detail
paddle::platform::DeviceContext* GetDeviceContextByBackend( pten::DeviceContext* GetDeviceContextByBackend(pten::Backend backend) {
pten::Backend backend) {
auto& pool = paddle::platform::DeviceContextPool::Instance(); auto& pool = paddle::platform::DeviceContextPool::Instance();
return pool.Get(pten::TransToFluidPlace(backend)); return pool.Get(pten::TransToFluidPlace(backend));
} }
......
...@@ -21,31 +21,22 @@ limitations under the License. */ ...@@ -21,31 +21,22 @@ limitations under the License. */
#include "paddle/pten/api/include/tensor.h" #include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/lib/backend_set.h" #include "paddle/pten/api/lib/backend_set.h"
#include "paddle/pten/api/lib/data_type_set.h" #include "paddle/pten/api/lib/data_type_set.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/common/data_type.h" #include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h" #include "paddle/pten/common/layout.h"
// TODO(chenweihang): split Key, Kernel, Factory into diff files // TODO(chenweihang): split Key, Kernel, Factory into diff files
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
// TODO(shixiaowei): replaced by new DeviceContext later
using CPUContext = paddle::platform::CPUDeviceContext;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using CUDAContext = paddle::platform::CUDADeviceContext;
#endif
namespace detail { namespace detail {
BackendSet GetTensorBackendSet(const Tensor& t); BackendSet GetTensorBackendSet(const Tensor& t);
std::size_t CountLeadingZeros(uint64_t val); std::size_t CountLeadingZeros(uint64_t val);
} // namespace detail } // namespace detail
paddle::platform::DeviceContext* GetDeviceContextByBackend( pten::DeviceContext* GetDeviceContextByBackend(pten::Backend backend);
pten::Backend backend);
// TODO(chenweihang): support DataLayout and DataType selected // TODO(chenweihang): support DataLayout and DataType selected
struct KernelKeySet { struct KernelKeySet {
......
cc_library(pten_context SRCS all_context.cc DEPS device_context)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/backends/all_context.h"
namespace pten {} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// Note: Some scenarios need to include all types of Context declarations.
// In order to avoid including the header files of each backend in turn,
// add this header file
// Note: Limit the entry of DeviceContext to backends to avoid multiple include
// path replacement after implementing pten DeviceContext
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/cuda/cuda_context.h"
#include "paddle/pten/backends/npu/npu_context.h"
#include "paddle/pten/backends/xpu/xpu_context.h"
namespace pten {
using DeviceContext = paddle::platform::DeviceContext;
using DeviceContextPool = paddle::platform::DeviceContextPool;
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace pten {
using CPUContext = paddle::platform::CPUDeviceContext;
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace pten {
using CUDAContext = paddle::platform::CUDADeviceContext;
} // namespace pten
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_ASCEND_CL
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace pten {
using NPUContext = paddle::platform::NPUDeviceContext;
} // namespace pten
#endif // PADDLE_WITH_ASCEND_CL
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace pten {
using XPUContext = paddle::platform::XPUDeviceContext;
} // namespace pten
#endif // PADDLE_WITH_XPU
...@@ -7,7 +7,7 @@ else() ...@@ -7,7 +7,7 @@ else()
endif() endif()
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils) cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce device_context) cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
...@@ -21,26 +22,10 @@ ...@@ -21,26 +22,10 @@
#include "paddle/pten/core/kernel_def.h" #include "paddle/pten/core/kernel_def.h"
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace pten { namespace pten {
// TODO(shixiaowei): replaced by new DeviceContext later
using CPUContext = paddle::platform::CPUDeviceContext;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using CUDAContext = paddle::platform::CUDADeviceContext;
#endif
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNContext = paddle::platform::MKLDNNDeviceContext;
#endif
#ifdef PADDLE_WITH_ASCEND_CL
using NPUContext = paddle::platform::NPUDeviceContext;
#endif
#ifdef PADDLE_WITH_XPU
using XPUContext = paddle::platform::XPUDeviceContext;
#endif
#define PT_KERNEL(...) \ #define PT_KERNEL(...) \
::pten::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute ::pten::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册