diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h index f2e797575fd576c0473782eed103d077195c3542..f548696bf999941371d0ef4aa45498f93be08f3e 100644 --- a/paddle/phi/api/include/context_pool.h +++ b/paddle/phi/api/include/context_pool.h @@ -87,3 +87,10 @@ class PADDLE_API DeviceContextPool { } // namespace experimental } // namespace paddle + +namespace phi { +class Allocator; + +PADDLE_API Allocator* GetAllocator(const Place& place); + +} // namespace phi diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index bcd355ab6cfdce0e15ce0062e4b52c2b137eccb4..84a49436b6d1f1e0ca58685469b50dea0e38cc50 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -289,7 +289,7 @@ cc_library( cc_library( context_pool SRCS context_pool.cc - DEPS phi_backends phi_enforce place init) + DEPS phi_backends phi_enforce place init phi_device_context) cc_library( kernel_dispatch diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index f3b148fb7bc9ddb4b5b3a4b7b5ec6a464b254f9f..a17df04183e82f684cee3f244f46c6f6152da461 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/allocator.h" #include "paddle/phi/core/enforce.h" #include "paddle/fluid/platform/init.h" @@ -50,3 +51,13 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { } // namespace experimental } // namespace paddle + +namespace phi { + +PADDLE_API Allocator* GetAllocator(const Place& place) { + const DeviceContext* dev_ctx = + paddle::experimental::DeviceContextPool::Instance().Get(place); + return const_cast(&dev_ctx->GetAllocator()); +} + +} // namespace phi diff --git a/paddle/phi/core/allocator.cc b/paddle/phi/core/allocator.cc index 4d766d7003f6b8ed3d6119b75f271951d3056f32..76e5c38c51ae1655efbf66ae8d071bdecc248642 100644 --- a/paddle/phi/core/allocator.cc +++ b/paddle/phi/core/allocator.cc @@ -14,4 +14,7 @@ limitations under the License. */ #include "paddle/phi/core/allocator.h" +#include "paddle/phi/api/include/context_pool.h" +#include "paddle/phi/core/device_context.h" + namespace phi {} // namespace phi diff --git a/paddle/phi/core/allocator.h b/paddle/phi/core/allocator.h index 849fc1548c7ec91a78899777aefa8aa58d61b3df..9595a51ec0316ab4df7a738ba0579fe5f1473e25 100644 --- a/paddle/phi/core/allocator.h +++ b/paddle/phi/core/allocator.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include +#include "paddle/phi/api/include/dll_decl.h" #include "paddle/phi/common/place.h" namespace phi { diff --git a/paddle/phi/tests/api/CMakeLists.txt b/paddle/phi/tests/api/CMakeLists.txt index 50973aea52483a4207d69bf302cbbd07d172954b..7ee6bd692b7e50df3a6b622d8ec64cf108911a29 100644 --- a/paddle/phi/tests/api/CMakeLists.txt +++ b/paddle/phi/tests/api/CMakeLists.txt @@ -5,11 +5,19 @@ if(WITH_GPU) test_phi_tensor SRCS test_phi_tensor.cc DEPS glog selected_rows ${COMMON_API_TEST_DEPS}) + nv_test( + test_allocator + SRCS test_allocator.cu + DEPS memory place device_context context_pool) elseif(WITH_ROCM) hip_test( test_phi_tensor SRCS test_phi_tensor.cc DEPS glog selected_rows ${COMMON_API_TEST_DEPS}) + hip_test( + test_allocator + SRCS test_allocator.cu + DEPS memory place device_context context_pool) else() cc_test( test_phi_tensor diff --git a/paddle/phi/tests/api/test_allocator.cu b/paddle/phi/tests/api/test_allocator.cu new file mode 100644 index 0000000000000000000000000000000000000000..23738d9a5ff42638dd64cedb5b3d161bee1d5422 --- /dev/null +++ b/paddle/phi/tests/api/test_allocator.cu @@ -0,0 +1,73 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/phi/api/include/context_pool.h" + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/common/transform.h" +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/device_context.h" + +TEST(Allocator, CPU) { + phi::Allocator* allocator = phi::GetAllocator(phi::CPUPlace()); + auto cpu_allocation = allocator->Allocate(sizeof(float) * 4); + float* cpu_buf = static_cast(cpu_allocation->ptr()); + ASSERT_NE(cpu_buf, nullptr); + cpu_buf[0] = 1.0f; + cpu_buf[1] = 2.0f; + cpu_buf[2] = 3.0f; + cpu_buf[3] = 4.0f; + for (size_t i = 0; i < 4; ++i) { + cpu_buf[i] = cpu_buf[i] + 1; + } + for (size_t i = 0; i < 4; ++i) { + ASSERT_NEAR(cpu_buf[i], static_cast(2.0 + i), 1e-5); + } +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +using paddle::memory::Copy; + +template +class Scale { + public: + explicit Scale(const T& scale) : scale_(scale) {} + HOSTDEVICE T operator()(const T& a) const { return a * scale_; } + + private: + T scale_; +}; + +TEST(Allocator, GPU) { + phi::GPUPlace gpu0(0); + float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; + phi::Allocator* allocator = phi::GetAllocator(gpu0); + auto gpu_allocation = allocator->Allocate(sizeof(cpu_buf)); + float* gpu_buf = static_cast(gpu_allocation->ptr()); + + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* ctx = reinterpret_cast(pool.Get(gpu0)); + Copy(gpu0, gpu_buf, phi::CPUPlace(), cpu_buf, sizeof(cpu_buf), ctx->stream()); + phi::Transform trans; + trans(*ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + ctx->Wait(); + Copy(phi::CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf), ctx->stream()); + for (int i = 0; i < 4; ++i) { + ASSERT_NEAR(cpu_buf[i], static_cast(i + 1), 1e-5); + } +} +#endif