Unverified commit 1ff1c1e0, authored by JingZhuangzhuang, committed by GitHub

add share external data interface (#39809)

Parent e4dba69a
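
This change adds a zero-copy Tensor::ShareExternalData interface, together with CPU and GPU unit tests, the core implementation, and the public Tensor API declaration. As a quick orientation, here is a minimal sketch of the CPU usage, assembled from the test added below (FLAGS_dirname and the input name "firstw" come from that test and stand in for a real model directory and input):

Config config;
config.SetModel(FLAGS_dirname);
auto predictor = CreatePredictor(config);
auto w0 = predictor->GetInputHandle("firstw");
std::vector<int64_t> buf{0, 1, 2, 3};
// The tensor borrows this buffer rather than copying it, so keep it alive
// at least until Run() returns.
w0->ShareExternalData<int64_t>(buf.data(), {4, 1}, PlaceType::kCPU);
predictor->Run();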
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#if defined(PADDLE_WITH_CUDA)
#include <cuda_runtime.h>
#endif
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <thread> // NOLINT
......@@ -405,4 +408,83 @@ TEST(Predictor, Run) {
predictor->TryShrinkMemory();
}
TEST(Tensor, CpuShareExternalData) {
Config config;
config.SetModel(FLAGS_dirname);
auto predictor = CreatePredictor(config);
auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw");
auto w2 = predictor->GetInputHandle("thirdw");
auto w3 = predictor->GetInputHandle("forthw");
std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
w0->ShareExternalData<int64_t>(input_data[0].data(), {4, 1}, PlaceType::kCPU);
w1->ShareExternalData<int64_t>(input_data[1].data(), {4, 1}, PlaceType::kCPU);
w2->ShareExternalData<int64_t>(input_data[2].data(), {4, 1}, PlaceType::kCPU);
w3->ShareExternalData<int64_t>(input_data[3].data(), {4, 1}, PlaceType::kCPU);
auto out = predictor->GetOutputHandle("fc_1.tmp_2");
auto out_shape = out->shape();
std::vector<float> out_data;
out_data.resize(std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()));
out->ShareExternalData<float>(out_data.data(), out_shape, PlaceType::kCPU);
predictor->Run();
PlaceType place;
int size = 0;
out->data<float>(&place, &size);
LOG(INFO) << "output size: " << size / sizeof(float);
predictor->TryShrinkMemory();
}
#if defined(PADDLE_WITH_CUDA)
TEST(Tensor, GpuShareExternalData) {
Config config;
config.SetModel(FLAGS_dirname);
config.EnableUseGpu(100, 0);
auto predictor = CreatePredictor(config);
auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw");
auto w2 = predictor->GetInputHandle("thirdw");
auto w3 = predictor->GetInputHandle("forthw");
std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
std::vector<int64_t*> input_gpu(4, nullptr);
for (size_t i = 0; i < 4; ++i) {
cudaMalloc(reinterpret_cast<void**>(&input_gpu[i]), 4 * sizeof(int64_t));
cudaMemcpy(input_gpu[i], input_data[i].data(), 4 * sizeof(int64_t),
cudaMemcpyHostToDevice);
}
w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1}, PlaceType::kGPU);
w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1}, PlaceType::kGPU);
w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1}, PlaceType::kGPU);
w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1}, PlaceType::kGPU);
auto out = predictor->GetOutputHandle("fc_1.tmp_2");
auto out_shape = out->shape();
float* out_data;
auto out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()) *
sizeof(float);
cudaMalloc(reinterpret_cast<void**>(&out_data), out_size);
out->ShareExternalData<float>(out_data, out_shape, PlaceType::kGPU);
predictor->Run();
PlaceType place;
int size = 0;
out->data<float>(&place, &size);
LOG(INFO) << "output size: " << size / sizeof(float);
predictor->TryShrinkMemory();
}
#endif
} // namespace paddle_infer
......@@ -21,6 +21,7 @@
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
namespace paddle_infer {
......@@ -205,6 +206,73 @@ void Tensor::CopyFromCpu(const T *data) {
}
}
template <typename T>
struct DataTypeInfo;
template <>
struct DataTypeInfo<float> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};
template <>
struct DataTypeInfo<float16> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};
template <>
struct DataTypeInfo<int64_t> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};
template <>
struct DataTypeInfo<int8_t> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};
template <>
struct DataTypeInfo<uint8_t> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};
template <>
struct DataTypeInfo<int32_t> {
paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};
paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
PADDLE_ENFORCE_EQ(
layout, DataLayout::kNCHW,
paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
return paddle::experimental::DataLayout::NCHW;
}
template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
PlaceType place, DataLayout layout) {
EAGER_GET_TENSOR(paddle::framework::LoDTensor)
size_t size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
LayoutConvert(layout));
if (place == PlaceType::kCPU) {
phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
paddle::platform::CPUPlace()),
meta);
*tensor = std::move(dtensor);
} else if (place == PlaceType::kGPU) {
phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
paddle::platform::CUDAPlace(device_)),
meta);
*tensor = std::move(dtensor);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
}
}
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
EAGER_GET_TENSOR(paddle_infer::Strings);
PADDLE_ENFORCE_GE(tensor->size(), 0,
......@@ -334,6 +402,25 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
const float *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
const int64_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
const int32_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
const uint8_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
const int8_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
const float16 *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
......
......@@ -47,6 +47,8 @@ enum DataType {
enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU, kNPU, kIPU };
enum class DataLayout { kUNK = -1, kAny, kNHWC, kNCHW };
/// \brief Represents an n-dimensional array of values.
/// The Tensor is used to store the input or output of the network.
/// Zero copy means that the tensor supports direct copy of host or device data
......@@ -92,6 +94,17 @@ class PD_INFER_DECL Tensor {
template <typename T>
void CopyFromCpu(const T* data);
/// \brief Share external data with the tensor, without copying.
/// It's usually used to set the tensor data directly from a user-owned buffer.
/// \param data The pointer to the data that the tensor will share.
/// \param shape The shape of the data.
/// \param place The place (device) where the data resides.
/// \param layout The layout of the data. Only NCHW is supported now.
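///
/// Usage sketch (illustrative, based on the unit test added in this commit;
/// `input` stands for a handle obtained from GetInputHandle, and the shared
/// buffer must stay valid until Run() finishes):
///   std::vector<int64_t> buf{0, 1, 2, 3};
///   input->ShareExternalData<int64_t>(buf.data(), {4, 1}, PlaceType::kCPU);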
template <typename T>
void ShareExternalData(const T* data, const std::vector<int>& shape,
PlaceType place,
DataLayout layout = DataLayout::kNCHW);
/// \brief Experimental interface.
/// It's usually used to set the input tensor data with Strings data type.
/// \param data The pointer of the data, from which the tensor will copy.
......
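
For the GPU path, the same call accepts a device pointer with PlaceType::kGPU, as the GPU unit test above shows. Below is a hedged sketch of that flow, reusing the predictor and the w0 handle from the CPU sketch near the top; since the implementation wraps the raw pointer in a phi::Allocation without registering a deleter, the caller appears to keep ownership of the device memory and should free it once the tensor is no longer used:

std::vector<int64_t> host_buf{0, 1, 2, 3};
int64_t* dev_buf = nullptr;
// Allocate device memory and copy the host data into it.
cudaMalloc(reinterpret_cast<void**>(&dev_buf), host_buf.size() * sizeof(int64_t));
cudaMemcpy(dev_buf, host_buf.data(), host_buf.size() * sizeof(int64_t),
           cudaMemcpyHostToDevice);
// Share the device buffer with the input tensor (no copy is made).
w0->ShareExternalData<int64_t>(dev_buf, {4, 1}, PlaceType::kGPU);
predictor->Run();
// Assumption: the tensor does not own the buffer, so free it ourselves
// once it is no longer needed.
cudaFree(dev_buf);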