Unverified · Commit 74442f5e authored by Huang Jiyi, committed by GitHub

Add from_blob api for constructing tensor from data pointer (#51085)

* add from_blob

* fix test

* fix test

* fix codestyle

* add gpu test

* fix test

* update

* add comment

* fix comment

* update comment

* fix CI bug

* add thread_local

* update

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix cmake

* fix CI-Py3 make

* update

* use api_reg

* fix include

* update

* update

* update

* fix bug

* fix bug

* fix bug

* fix bug
Parent fdcfa04f
@@ -323,7 +323,8 @@ cc_library(
api_gen_utils
phi_data_transform
api_custom_impl
phi_profiler)
phi_profiler
from_blob)
cc_library(
phi_bw_function_api
SRCS ${bw_api_source_file}
@@ -390,6 +391,10 @@ cc_library(
api_int_array
SRCS int_array.cc
DEPS tensor_copy)
cc_library(
from_blob
SRCS from_blob.cc
DEPS phi_tensor_raw)
cc_library(
phi_tensor_operants
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/lib/from_blob.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/core/dense_tensor.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#else
#include <hip/hip_runtime.h>
#endif
#endif
namespace paddle {
namespace experimental {
PD_REGISTER_API(from_blob)
phi::Place GetPlaceFromPtr(void* data);
using AllocationDeleter = void (*)(phi::Allocation*);
PADDLE_API Tensor from_blob(void* data,
const phi::DDim& shape,
phi::DataType dtype,
phi::DataLayout layout,
const phi::Place& place,
const Deleter& deleter) {
PADDLE_ENFORCE_NOT_NULL(
      data, phi::errors::InvalidArgument("data cannot be nullptr"));
auto data_place = GetPlaceFromPtr(data);
  // TODO(huangjiyi): We need to copy the data to the specified place when
  // the input place differs from the place of `data`.
if (place.GetType() != phi::AllocationType::UNDEFINED) {
PADDLE_ENFORCE_EQ(
data_place,
place,
phi::errors::InvalidArgument("Specified ",
data_place.DebugString(),
" does not match place of data ",
place.DebugString()));
}
auto meta = phi::DenseTensorMeta(dtype, shape, layout);
size_t size = SizeOf(dtype) * (meta.is_scalar ? 1 : product(meta.dims));
  AllocationDeleter alloc_deleter = nullptr;
  if (deleter) {
    // phi::Allocation stores a plain function pointer, so a capturing lambda
    // cannot be passed directly. Stash the user deleter in a thread_local
    // instead; note that a later from_blob call on the same thread overwrites
    // it, so only the most recently registered deleter is kept per thread.
    static thread_local Deleter g_deleter = deleter;
    alloc_deleter = [](phi::Allocation* p) { g_deleter(p->ptr()); };
  }
auto alloc =
std::make_shared<phi::Allocation>(data, size, alloc_deleter, data_place);
return Tensor(std::make_shared<phi::DenseTensor>(alloc, meta));
}
phi::Place GetPlaceFromPtr(void* data) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10000
cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, data);
if (status == cudaSuccess && attr.type == cudaMemoryTypeDevice) {
return phi::GPUPlace(attr.device);
}
#else
PADDLE_THROW(
phi::errors::Unimplemented("The GetPlaceFromPtr() method is only "
"supported when CUDA version >= 10.0."));
#endif
#else
hipPointerAttribute_t attr;
hipError_t status = hipPointerGetAttributes(&attr, data);
if (status == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
return phi::GPUPlace(attr.device);
}
#endif
#endif
return phi::CPUPlace();
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace experimental {
using Deleter = std::function<void(void*)>;
/**
* @brief Construct a Tensor from a buffer pointed to by `data`
*
 * @note `from_blob` doesn't copy or move the data; modifying the constructed
 * tensor is equivalent to modifying the original data.
*
* @param data The pointer to the memory buffer.
* @param shape The dims of the tensor.
* @param dtype The data type of the tensor, should correspond to data type of
* `data`. See PD_FOR_EACH_DATA_TYPE in phi/common/data_type.h
* @param layout The data layout of the tensor.
* @param place The place where the tensor is located, should correspond to
* place of `data`.
* @param deleter A function or function object that will be called to free the
* memory buffer.
*
* @return A Tensor object constructed from the buffer
*/
PADDLE_API Tensor from_blob(void* data,
const phi::DDim& shape,
phi::DataType dtype,
phi::DataLayout layout = phi::DataLayout::NCHW,
const phi::Place& place = phi::Place(),
const Deleter& deleter = nullptr);
} // namespace experimental
} // namespace paddle
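
For reference, a minimal usage sketch of the API declared above (illustrative caller code, not part of this commit; the example() helper, the malloc'd buffer, and the std::free deleter are assumptions):

#include <cstdlib>

#include "paddle/phi/api/lib/from_blob.h"

void example() {
  // Wrap an existing heap buffer without copying; the tensor aliases `buf`.
  auto* buf = static_cast<float*>(std::malloc(6 * sizeof(float)));
  for (int i = 0; i < 6; ++i) buf[i] = static_cast<float>(i);
  auto t = paddle::experimental::from_blob(
      buf,
      {2, 3},
      phi::DataType::FLOAT32,
      phi::DataLayout::NCHW,
      phi::CPUPlace(),
      [](void* p) { std::free(p); });  // invoked when the allocation is freed
  // Writes through `t` are visible in `buf`, and vice versa.
}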
@@ -363,7 +363,9 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/from_blob.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
@@ -396,6 +398,13 @@ namespace experimental {
)
def declare_extension_api():
return """
PD_DECLARE_API(from_blob);
"""
def generate_api(api_yaml_path, header_file_path, source_file_path):
apis = []
@@ -426,6 +435,8 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
source_file.write(declare_extension_api())
header_file.write(namespace[1])
source_file.write(namespace[1])
......
@@ -21,7 +21,7 @@ limitations under the License. */
namespace phi {
/* --------------------------- */
/* From phi::DenseTensor */
/* --------------------------- */
DenseTensor::DenseTensor() {
meta_.dtype = paddle::experimental::DataType::FLOAT32;
......
@@ -8,11 +8,15 @@ if(WITH_GPU)
nv_test(
test_allocator
SRCS test_allocator.cu
DEPS memory place device_context context_pool)
DEPS place device_context context_pool)
nv_test(
test_cuda_stream
SRCS test_cuda_stream.cu
DEPS context_pool)
nv_test(
test_from_blob
SRCS test_from_blob.cc
DEPS phi_backends ${COMMON_API_TEST_DEPS})
elseif(WITH_ROCM)
hip_test(
test_phi_tensor
@@ -21,16 +25,24 @@ elseif(WITH_ROCM)
hip_test(
test_allocator
SRCS test_allocator.cu
DEPS memory place device_context context_pool)
DEPS place device_context context_pool)
hip_test(
test_cuda_stream
SRCS test_cuda_stream.cu
DEPS context_pool)
hip_test(
test_from_blob
SRCS test_from_blob.cc
DEPS phi_backends ${COMMON_API_TEST_DEPS})
else()
cc_test(
test_phi_tensor
SRCS test_phi_tensor.cc
DEPS glog selected_rows ${COMMON_API_TEST_DEPS})
cc_test(
test_from_blob
SRCS test_from_blob.cc
DEPS phi_backends ${COMMON_API_TEST_DEPS})
endif()
cc_test(
......
@@ -16,14 +16,14 @@ limitations under the License. */
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/device_context.h"
using paddle::memory::Copy;
using phi::memory_utils::Copy;
template <typename T>
class Scale {
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/lib/from_blob.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/memory_utils.h"
#endif
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
#endif
using paddle::experimental::DataType;
using paddle::experimental::from_blob;
namespace paddle {
namespace experimental {
phi::Place GetPlaceFromPtr(void* data);
} // namespace experimental
} // namespace paddle
TEST(from_blob, CPU) {
// 1. create data
int64_t data[] = {4, 3, 2, 1};
ASSERT_EQ(paddle::experimental::GetPlaceFromPtr(data), phi::CPUPlace());
// 2. test API
  auto test_tensor = from_blob(data, {1, 2, 2}, DataType::INT64);
  // 3. check result
  // 3.1 check tensor attributes
  ASSERT_EQ(test_tensor.dims().size(), 3);
  ASSERT_EQ(test_tensor.dims()[0], 1);
  ASSERT_EQ(test_tensor.dims()[1], 2);
  ASSERT_EQ(test_tensor.dims()[2], 2);
  ASSERT_EQ(test_tensor.numel(), 4);
  ASSERT_EQ(test_tensor.is_cpu(), true);
  ASSERT_EQ(test_tensor.dtype(), DataType::INT64);
  ASSERT_EQ(test_tensor.layout(), phi::DataLayout::NCHW);
  ASSERT_EQ(test_tensor.is_dense_tensor(), true);
  // 3.2 check tensor values
  auto* test_tensor_data = test_tensor.template data<int64_t>();
  for (int64_t i = 0; i < 4; i++) {
    ASSERT_EQ(test_tensor_data[i], 4 - i);
  }
  // 3.3 check whether memory is shared
  ASSERT_EQ(data, test_tensor_data);
  // 3.4 test other API
  auto test_tensor_pow = paddle::experimental::pow(test_tensor, 2);
auto* test_tensor_pow_data = test_tensor_pow.template data<int64_t>();
for (int64_t i = 0; i < 4; i++) {
ASSERT_EQ(test_tensor_pow_data[i],
static_cast<int64_t>(std::pow(4 - i, 2)));
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using phi::memory_utils::Copy;
TEST(GetPlaceFromPtr, GPU) {
using paddle::experimental::GetPlaceFromPtr;
float cpu_data[6];
auto cpu_data_place = GetPlaceFromPtr(cpu_data);
ASSERT_EQ(cpu_data_place, phi::CPUPlace());
std::cout << "cpu_data_place: " << cpu_data_place << std::endl;
float* gpu0_data = static_cast<float*>(paddle::GetAllocator(phi::GPUPlace(0))
->Allocate(sizeof(cpu_data))
->ptr());
auto gpu0_data_place = GetPlaceFromPtr(gpu0_data);
ASSERT_EQ(gpu0_data_place, phi::GPUPlace(0));
std::cout << "gpu0_data_place: " << gpu0_data_place << std::endl;
if (phi::backends::gpu::GetGPUDeviceCount() > 1) {
float* gpu1_data =
static_cast<float*>(paddle::GetAllocator(phi::GPUPlace(1))
->Allocate(sizeof(cpu_data))
->ptr());
auto gpu1_data_place = GetPlaceFromPtr(gpu1_data);
ASSERT_EQ(gpu1_data_place, phi::GPUPlace(1));
std::cout << "gpu1_data_place: " << gpu1_data_place << std::endl;
}
}
TEST(from_blob, GPU) {
// 1. create data
float cpu_data[6] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
phi::GPUPlace gpu0(0);
phi::Allocator* allocator = paddle::GetAllocator(gpu0);
auto gpu_allocation = allocator->Allocate(sizeof(cpu_data));
float* gpu_data = static_cast<float*>(gpu_allocation->ptr());
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(gpu0));
Copy(gpu0,
gpu_data,
phi::CPUPlace(),
cpu_data,
sizeof(cpu_data),
ctx->stream());
// 2. test API
  auto gpu_tensor = from_blob(gpu_data, {2, 3}, DataType::FLOAT32);
  // 3. check result
  // 3.1 check tensor attributes
  ASSERT_EQ(gpu_tensor.dims().size(), 2);
  ASSERT_EQ(gpu_tensor.dims()[0], 2);
  ASSERT_EQ(gpu_tensor.dims()[1], 3);
  ASSERT_EQ(gpu_tensor.numel(), 6);
  // ASSERT_EQ(gpu_tensor.is_gpu(), true);
  ASSERT_EQ(gpu_tensor.dtype(), DataType::FLOAT32);
  // 3.2 check tensor values
  auto* gpu_tensor_data = gpu_tensor.template data<float>();
  float gpu_tensor_data_cpu[6];
  Copy(phi::CPUPlace(),
       gpu_tensor_data_cpu,
       gpu0,
       gpu_tensor_data,
       sizeof(cpu_data),
       ctx->stream());
  for (int64_t i = 0; i < 6; i++) {
    ASSERT_NEAR(
        gpu_tensor_data_cpu[i], static_cast<float>((i + 1) * 0.1f), 1e-5);
  }
  // 3.3 check whether memory is shared
  ASSERT_EQ(gpu_data, gpu_tensor_data);
  // 3.4 test other API
  auto gpu_tensor_pow = paddle::experimental::pow(gpu_tensor, 2);
  auto* gpu_tensor_pow_data = gpu_tensor_pow.template data<float>();
  float gpu_tensor_pow_data_cpu[6];
  Copy(phi::CPUPlace(),
       gpu_tensor_pow_data_cpu,
       gpu0,
       gpu_tensor_pow_data,
       sizeof(cpu_data),
       ctx->stream());
  for (int64_t i = 0; i < 6; i++) {
    ASSERT_NEAR(gpu_tensor_pow_data_cpu[i],
                static_cast<float>(std::pow(i + 1, 2) * 0.01f),
                1e-5);
  }
}
#endif
TEST(from_blob, Option) {
// 1. create data
auto data = new int64_t[8];
for (int64_t i = 0; i < 8; i++) {
data[i] = i;
}
// 2. test Deleter and Layout
int isdelete = 0;
auto deleter = [&isdelete](void* data) {
delete[] static_cast<int64_t*>(data);
isdelete++;
};
{
    auto test_tensor = from_blob(data,
{1, 2, 2, 1},
DataType::INT64,
phi::DataLayout::NHWC,
phi::CPUPlace(),
deleter);
// check tensor attributes
    ASSERT_EQ(test_tensor.layout(), phi::DataLayout::NHWC);  // check layout
// check deleter
ASSERT_EQ(isdelete, 0);
}
ASSERT_EQ(isdelete, 1);
}