未验证 提交 bc7632be 编写于 作者: 石晓伟 提交者: GitHub

upgrade inference tensor apis, test=develop (#31402)

上级 8491ae9a
......@@ -1195,20 +1195,6 @@ USE_TRT_CONVERTER(clip);
namespace paddle_infer {
void Tensor::Reshape(const std::vector<int> &shape) { tensor_->Reshape(shape); }
std::vector<int> Tensor::shape() const { return tensor_->shape(); }
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
return tensor_->SetLoD(x);
}
std::vector<std::vector<size_t>> Tensor::lod() const { return tensor_->lod(); }
const std::string &Tensor::name() const { return tensor_->name(); }
DataType Tensor::type() const { return tensor_->type(); }
Predictor::Predictor(const Config &config) {
const_cast<Config *>(&config)->SwitchUseFeedFetchOps(false);
// The second parameter indicates that the discard log is not printed
......@@ -1221,9 +1207,7 @@ std::vector<std::string> Predictor::GetInputNames() {
}
std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetInputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
return predictor_->GetInputTensor(name);
}
std::vector<std::string> Predictor::GetOutputNames() {
......@@ -1231,9 +1215,7 @@ std::vector<std::string> Predictor::GetOutputNames() {
}
std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetOutputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
return predictor_->GetOutputTensor(name);
}
bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
......
......@@ -16,3 +16,5 @@
cc_library(reset_tensor_array SRCS reset_tensor_array.cc DEPS lod_tensor scope)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc)
cc_test(zero_copy_tensor_test SRCS zero_copy_tensor_test.cc DEPS paddle_inference_api)
......@@ -18,126 +18,135 @@
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace paddle_infer {
void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
void Tensor::Reshape(const std::vector<int> &shape) {
PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
PADDLE_ENFORCE_EQ(input_or_output_, true,
platform::errors::PermissionDenied(
paddle::platform::errors::PermissionDenied(
"Can't reshape the output tensor, it is readonly"));
PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_);
auto *scope = static_cast<paddle::framework::Scope *>(scope_);
auto *var = scope->FindVar(name_);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
var, paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(shape));
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize(paddle::framework::make_ddim(shape));
}
#define EAGER_GET_TENSOR \
if (!tensor_) { \
tensor_ = FindTensor(); \
} \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_);
auto *tensor = static_cast<paddle::framework::LoDTensor *>(tensor_);
template <typename T>
T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
T *Tensor::mutable_data(PlaceType place) {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_GT(
tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call ZeroCopyTensor::Reshape(const std::vector<int> "
paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const std::vector<int> "
"&shape)"
"function before retrieving mutable_data from input tensor."));
switch (static_cast<int>(place)) {
case static_cast<int>(PaddlePlace::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
case static_cast<int>(PlaceType::kCPU): {
return tensor->mutable_data<T>(paddle::platform::CPUPlace());
}
case static_cast<int>(PaddlePlace::kGPU): {
return tensor->mutable_data<T>(platform::CUDAPlace(device_));
case static_cast<int>(PlaceType::kGPU): {
return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
}
case static_cast<int>(PlaceType::kXPU): {
return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupported place: %d",
static_cast<int>(place)));
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Only CPU / CUDA / XPU places is supported. The place `%d` is not "
"supported.",
static_cast<int>(place)));
break;
}
return nullptr;
}
template <typename T>
T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
T *Tensor::data(PlaceType *place, int *size) const {
EAGER_GET_TENSOR;
auto *res = tensor->data<T>();
if (platform::is_cpu_place(tensor->place())) {
*place = PaddlePlace::kCPU;
} else if (platform::is_gpu_place(tensor->place())) {
*place = PaddlePlace::kGPU;
if (paddle::platform::is_cpu_place(tensor->place())) {
*place = PlaceType::kCPU;
} else if (paddle::platform::is_gpu_place(tensor->place())) {
*place = PlaceType::kGPU;
} else if (paddle::platform::is_xpu_place(tensor->place())) {
*place = PlaceType::kXPU;
} else {
*place = PaddlePlace::kUNK;
*place = PlaceType::kUNK;
}
*size = tensor->numel();
return res;
}
PaddleDType ZeroCopyTensor::type() const {
DataType Tensor::type() const {
EAGER_GET_TENSOR;
auto type = tensor->type();
if (type == framework::proto::VarType::FP32) {
return PaddleDType::FLOAT32;
} else if (type == framework::proto::VarType::INT64) {
return PaddleDType::INT64;
} else if (type == framework::proto::VarType::INT32) {
return PaddleDType::INT32;
} else if (type == framework::proto::VarType::UINT8) {
return PaddleDType::UINT8;
if (type == paddle::framework::proto::VarType::FP32) {
return DataType::FLOAT32;
} else if (type == paddle::framework::proto::VarType::INT64) {
return DataType::INT64;
} else if (type == paddle::framework::proto::VarType::INT32) {
return DataType::INT32;
} else if (type == paddle::framework::proto::VarType::UINT8) {
return DataType::UINT8;
}
return PaddleDType::FLOAT32;
return DataType::FLOAT32;
}
template <typename T>
void ZeroCopyTensor::copy_from_cpu(const T *data) {
void Tensor::CopyFromCpu(const T *data) {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call ZeroCopyTensor::Reshape(const "
paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
if (place_ == PaddlePlace::kCPU) {
auto *t_data = tensor->mutable_data<T>(platform::CPUPlace());
if (place_ == PlaceType::kCPU) {
auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
std::memcpy(static_cast<void *>(t_data), data, ele_size);
} else if (place_ == PaddlePlace::kGPU) {
} else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::CUDAPlace gpu_place(device_);
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::CUDAPlace gpu_place(device_);
auto *t_data = tensor->mutable_data<T>(gpu_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place));
memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size, dev_ctx->stream());
paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
paddle::platform::CPUPlace(), data, ele_size,
dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with CUDA, should not reach here."));
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."));
#endif
} else if (place_ == PaddlePlace::kXPU) {
} else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
platform::XPUPlace xpu_place(device_);
paddle::platform::XPUPlace xpu_place(device_);
auto *t_data = tensor->mutable_data<T>(xpu_place);
memory::Copy(xpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size);
paddle::memory::Copy(xpu_place, static_cast<void *>(t_data),
paddle::platform::CPUPlace(), data, ele_size);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with XPU, should not reach here."));
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with XPU place because paddle is not compiled "
"with XPU."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
......@@ -146,119 +155,119 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
}
template <typename T>
void ZeroCopyTensor::copy_to_cpu(T *data) {
void Tensor::CopyToCpu(T *data) {
EAGER_GET_TENSOR;
auto ele_num = tensor->numel();
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();
if (platform::is_cpu_place(t_place)) {
if (paddle::platform::is_cpu_place(t_place)) {
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
} else if (place_ == PaddlePlace::kGPU) {
} else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
t_data, ele_num * sizeof(T), dev_ctx->stream());
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(paddle::platform::CUDAPlace, t_place);
auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place));
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), gpu_place, t_data,
ele_num * sizeof(T), dev_ctx->stream());
#ifdef PADDLE_WITH_HIP
hipStreamSynchronize(dev_ctx->stream());
#else
cudaStreamSynchronize(dev_ctx->stream());
#endif
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with CUDA, should not reach here."));
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."));
#endif
} else if (place_ == PaddlePlace::kXPU) {
} else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
auto xpu_place = BOOST_GET_CONST(platform::XPUPlace, t_place);
memory::Copy(platform::CPUPlace(), static_cast<void *>(data), xpu_place,
t_data, ele_num * sizeof(T));
auto xpu_place = BOOST_GET_CONST(paddle::platform::XPUPlace, t_place);
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), xpu_place, t_data,
ele_num * sizeof(T));
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with XPU, should not reach here."));
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with XPU place because paddle is not compiled "
"with XPU."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The analysis predictor supports CPU, GPU and XPU now."));
}
}
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<float>(
const float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int64_t>(
const int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int32_t>(
const int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<uint8_t>(
const uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int8_t>(
const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data);
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data);
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data);
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data);
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int8_t>(int8_t *data);
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
int *size) const;
template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const;
template PD_INFER_DECL int64_t *ZeroCopyTensor::data<int64_t>(
PaddlePlace *place, int *size) const;
template PD_INFER_DECL int32_t *ZeroCopyTensor::data<int32_t>(
PaddlePlace *place, int *size) const;
template PD_INFER_DECL uint8_t *ZeroCopyTensor::data<uint8_t>(
PaddlePlace *place, int *size) const;
template PD_INFER_DECL int8_t *ZeroCopyTensor::data<int8_t>(PaddlePlace *place,
int *size) const;
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float *ZeroCopyTensor::mutable_data<float>(
PaddlePlace place);
template PD_INFER_DECL int64_t *ZeroCopyTensor::mutable_data<int64_t>(
PaddlePlace place);
template PD_INFER_DECL int32_t *ZeroCopyTensor::mutable_data<int32_t>(
PaddlePlace place);
template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
PaddlePlace place);
template PD_INFER_DECL int8_t *ZeroCopyTensor::mutable_data<int8_t>(
PaddlePlace place);
Tensor::Tensor(void *scope) : scope_{scope} {
PADDLE_ENFORCE_NOT_NULL(scope_,
paddle::platform::errors::PreconditionNotMet(
"The `scope` can not be nullptr. It should be "
"set to the pointer of scope."));
}
void *ZeroCopyTensor::FindTensor() const {
void *Tensor::FindTensor() const {
PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_);
auto *scope = static_cast<paddle::framework::Scope *>(scope_);
auto *var = scope->FindVar(name_);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
var, paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>();
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
return tensor;
}
std::vector<int> ZeroCopyTensor::shape() const {
std::vector<int> Tensor::shape() const {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_NOT_NULL(
tensor_, platform::errors::PreconditionNotMet(
tensor_, paddle::platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_));
return framework::vectorize<int>(tensor->dims());
return paddle::framework::vectorize<int>(tensor->dims());
}
void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
EAGER_GET_TENSOR;
framework::LoD lod;
paddle::framework::LoD lod;
for (auto &level : x) {
lod.emplace_back(level);
}
tensor->set_lod(lod);
}
std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
std::vector<std::vector<size_t>> Tensor::lod() const {
EAGER_GET_TENSOR;
std::vector<std::vector<size_t>> res;
for (auto &level : tensor->lod()) {
......@@ -267,4 +276,13 @@ std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
return res;
}
} // namespace paddle
void Tensor::SetName(const std::string &name) { name_ = name; }
const std::string &Tensor::name() const { return name_; }
void Tensor::SetPlace(PlaceType place, int device) {
place_ = place;
device_ = device;
}
} // namespace paddle_infer
......@@ -15,35 +15,35 @@
#include "paddle/fluid/inference/api/paddle_api.h"
#include "paddle/fluid/inference/api/paddle_infer_declare.h"
namespace paddle {
namespace paddle_infer {
void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {}
void Tensor::Reshape(const std::vector<int> &shape) {}
template <typename T>
T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
T *Tensor::mutable_data(PlaceType place) {
return nullptr;
}
template <typename T>
T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
T *Tensor::data(PlaceType *place, int *size) const {
return nullptr;
}
template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const;
template PD_INFER_DECL int64_t *ZeroCopyTensor::data<int64_t>(
PaddlePlace *place, int *size) const;
template float *ZeroCopyTensor::mutable_data(PaddlePlace place);
template int64_t *ZeroCopyTensor::mutable_data(PaddlePlace place);
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
int *size) const;
template float *Tensor::mutable_data(PlaceType place);
template int64_t *Tensor::mutable_data(PlaceType place);
void *ZeroCopyTensor::FindTensor() const { return nullptr; }
void *Tensor::FindTensor() const { return nullptr; }
std::vector<int> ZeroCopyTensor::shape() const { return {}; }
std::vector<int> Tensor::shape() const { return {}; }
void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {}
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {}
std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
std::vector<std::vector<size_t>> Tensor::lod() const {
return std::vector<std::vector<size_t>>();
}
} // namespace paddle
} // namespace paddle_infer
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle_infer {
struct TensorWrapper : public Tensor {
TensorWrapper(paddle_infer::PlaceType place, paddle::framework::Scope* scope,
const std::string& name)
: Tensor{static_cast<void*>(scope)} {
SetPlace(place, 0 /*device_id*/);
SetName(name);
input_or_output_ = true;
}
};
std::unique_ptr<Tensor> CreateTensor(paddle_infer::PlaceType place,
paddle::framework::Scope* scope,
const std::string& name) {
return std::unique_ptr<Tensor>(new TensorWrapper{place, scope, name});
}
template <typename T>
struct RandomGenerator {
RandomGenerator(double min = (std::numeric_limits<T>::min)(),
double max = (std::numeric_limits<T>::max)())
: dist_{static_cast<double>(min), static_cast<double>(max)} {}
T operator()() { return static_cast<T>(dist_(random_engine_)); }
private:
std::mt19937_64 random_engine_{std::random_device()()};
std::uniform_real_distribution<double> dist_;
};
template <typename T, template <typename> typename G>
bool FillRandomDataAndCheck(PlaceType place, size_t length, G<T>&& generator,
float threshold = 10e-5) {
std::vector<T> data_in(length);
std::generate(data_in.begin(), data_in.end(), std::forward<G<T>>(generator));
paddle::framework::Scope scope;
const std::string name{"name"};
scope.Var(name);
auto tensor = CreateTensor(place, &scope, name);
tensor->CopyFromCpu<T>(data_in.data());
if (tensor->type() != paddle::inference::ConvertToPaddleDType(
paddle::framework::DataTypeTrait<T>::DataType())) {
return false;
}
std::vector<T> data_out(length);
tensor->CopyToCpu<T>(data_out.data());
for (size_t i = 0; i < length; ++i) {
if (std::abs(data_out[i] - data_out[i]) > threshold) {
return false;
}
}
return true;
}
template <typename T>
bool SetPlaceAndCheck(PlaceType place, size_t length) {
paddle::framework::Scope scope;
const std::string name{"name"};
const std::vector<std::vector<size_t>> lod{{0, length}};
scope.Var(name);
auto tensor = CreateTensor(place, &scope, name);
tensor->Reshape({static_cast<int>(length)});
tensor->mutable_data<T>(place);
tensor->SetLoD(lod);
PlaceType place_out{PlaceType::kUNK};
int length_out{-1};
tensor->data<T>(&place_out, &length_out);
if (length_out != static_cast<int>(length) || place_out != place) {
return false;
}
if (tensor->name() != name || tensor->lod() != lod) {
return false;
}
return true;
}
bool FillRandomDataAndCheck(PlaceType place) {
const size_t length{RandomGenerator<size_t>{1, 1000}()};
VLOG(3) << "FillRandomDataAndCheck: length = " << length;
return FillRandomDataAndCheck<float>(place, length,
RandomGenerator<float>{}) &&
FillRandomDataAndCheck<int64_t>(place, length,
RandomGenerator<int64_t>{}) &&
FillRandomDataAndCheck<int32_t>(place, length,
RandomGenerator<int32_t>{}) &&
FillRandomDataAndCheck<uint8_t>(place, length,
RandomGenerator<uint8_t>{});
}
bool SetPlaceAndCheck(PlaceType place) {
const size_t length{RandomGenerator<size_t>{1, 1000}()};
VLOG(3) << "SetPlaceAndCheck: length = " << length;
return SetPlaceAndCheck<float>(place, length) &&
SetPlaceAndCheck<int64_t>(place, length) &&
SetPlaceAndCheck<int32_t>(place, length) &&
SetPlaceAndCheck<uint8_t>(place, length);
}
TEST(Tensor, FillRandomDataAndCheck) {
ASSERT_TRUE(FillRandomDataAndCheck(PlaceType::kCPU));
ASSERT_TRUE(SetPlaceAndCheck(PlaceType::kCPU));
#ifdef PADDLE_WITH_CUDA
ASSERT_TRUE(FillRandomDataAndCheck(PlaceType::kGPU));
ASSERT_TRUE(SetPlaceAndCheck(PlaceType::kGPU));
#endif
}
} // namespace paddle_infer
......@@ -58,6 +58,26 @@ constexpr PaddleDType PaddleTensorGetDType<float>() {
return PaddleDType::FLOAT32;
}
inline PaddleDType ConvertToPaddleDType(
paddle::framework::proto::VarType::Type type) {
if (type == paddle::framework::proto::VarType::FP32) {
return PaddleDType::FLOAT32;
} else if (type == paddle::framework::proto::VarType::INT64) {
return PaddleDType::INT64;
} else if (type == paddle::framework::proto::VarType::INT32) {
return PaddleDType::INT32;
} else if (type == paddle::framework::proto::VarType::UINT8) {
return PaddleDType::UINT8;
} else {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"The paddle dtype convert function only supports FLOAT32, INT64, INT32 "
"and UINT8 now. But "
"we get %d here.",
static_cast<int>(type)));
return PaddleDType::FLOAT32;
}
}
using paddle::framework::DataTypeToString;
// Timer for timer
......
......@@ -29,19 +29,13 @@
#include <vector>
#include "crypto/cipher.h"
#include "paddle_infer_declare.h" // NOLINT
#include "paddle_tensor.h" // NOLINT
/*! \namespace paddle
*/
namespace paddle {
/// \brief Paddle data type.
enum PaddleDType {
FLOAT32,
INT64,
INT32,
UINT8,
INT8,
// TODO(Superjomn) support more data types if needed.
};
using PaddleDType = paddle_infer::DataType;
using PaddlePlace = paddle_infer::PlaceType;
/// \brief Memory manager for PaddleTensor.
///
......@@ -162,8 +156,6 @@ struct PD_INFER_DECL PaddleTensor {
std::vector<std::vector<size_t>> lod; ///< Tensor+LoD equals LoDTensor
};
enum class PaddlePlace { kUNK = -1, kCPU, kGPU, kXPU };
/// \brief Represents an n-dimensional array of values.
/// The ZeroCopyTensor is used to store the input or output of the network.
/// Zero copy means that the tensor supports direct copy of host or device data
......@@ -172,79 +164,27 @@ enum class PaddlePlace { kUNK = -1, kCPU, kGPU, kXPU };
/// AnalysisPredictor.
/// It is obtained through PaddlePredictor::GetinputTensor()
/// and PaddlePredictor::GetOutputTensor() interface.
class PD_INFER_DECL ZeroCopyTensor {
public:
/// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling mutable_data() or copy_from_cpu()
/// \param shape The shape to set.
void Reshape(const std::vector<int>& shape);
/// \brief Get the memory pointer in CPU or GPU with specific data type.
/// Please Reshape the tensor first before call this.
/// It's usually used to get input data pointer.
/// \param place The place of the tensor.
template <typename T>
T* mutable_data(PaddlePlace place);
/// \brief Get the memory pointer directly.
/// It's usually used to get the output data pointer.
/// \param[out] place To get the device type of the tensor.
/// \param[out] size To get the data size of the tensor.
/// \return The tensor data buffer pointer.
template <typename T>
T* data(PaddlePlace* place, int* size) const;
class PD_INFER_DECL ZeroCopyTensor : public paddle_infer::Tensor {
public:
/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which the tensor will copy.
template <typename T>
void copy_from_cpu(const T* data);
void copy_from_cpu(const T* data) {
return CopyFromCpu(data);
}
/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to the address.
template <typename T>
void copy_to_cpu(T* data);
/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
/// \brief Set lod info of the tensor.
/// More about LOD can be seen here:
/// https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor
/// \param x the lod info.
void SetLoD(const std::vector<std::vector<size_t>>& x);
/// \brief Return the lod info of the tensor.
std::vector<std::vector<size_t>> lod() const;
/// \brief Return the name of the tensor.
const std::string& name() const { return name_; }
void SetPlace(PaddlePlace place, int device = -1) {
place_ = place;
device_ = device;
void copy_to_cpu(T* data) {
return CopyToCpu(data);
}
/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
/// \return The data type of the tensor.
PaddleDType type() const;
protected:
explicit ZeroCopyTensor(void* scope) : scope_{scope} {}
void SetName(const std::string& name) { name_ = name; }
void* FindTensor() const;
private:
std::string name_;
bool input_or_output_;
friend class AnalysisPredictor;
void* scope_{nullptr};
// The corresponding tensor pointer inside Paddle workspace is cached for
// performance.
mutable void* tensor_{nullptr};
PaddlePlace place_;
PaddleDType dtype_;
int device_;
explicit ZeroCopyTensor(void* scope) : paddle_infer::Tensor{scope} {}
};
/// \brief A Predictor for executing inference on a model.
......
......@@ -42,97 +42,10 @@ limitations under the License. */
///
namespace paddle_infer {
using DataType = paddle::PaddleDType;
using PlaceType = paddle::PaddlePlace;
using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig;
///
/// \class Tensor
///
/// \brief Represents an n-dimensional array of values.
/// The Tensor is used to store the input or output of the network.
/// It is obtained through Predictor::GetinputHandle()
/// and Predictor::GetOutputHandle() interface.
///
class PD_INFER_DECL Tensor {
public:
// Can only be created by predictor->GetInputHandle(cosnt std::string& name)
// or predictor->GetOutputHandle(cosnt std::string& name)
Tensor() = delete;
explicit Tensor(std::unique_ptr<paddle::ZeroCopyTensor>&& tensor)
: tensor_(std::move(tensor)) {}
///
/// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling mutable_data() or CopyFromCpu()
/// \param shape The shape to set.
///
void Reshape(const std::vector<int>& shape);
///
/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which the tensor will copy.
///
template <typename T>
void CopyFromCpu(const T* data);
///
/// \brief Get the memory pointer in CPU or GPU with specific data type.
/// Please Reshape the tensor first before call this.
/// It's usually used to get input data pointer.
/// \param place The place of the tensor.
/// \return The tensor data buffer pointer.
///
template <typename T>
T* mutable_data(PlaceType place);
///
/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to the address.
///
template <typename T>
void CopyToCpu(T* data);
///
/// \brief Get the memory pointer directly.
/// It's usually used to get the output data pointer.
/// \param[out] place To get the device type of the tensor.
/// \param[out] size To get the data size of the tensor.
/// \return The tensor data buffer pointer.
///
template <typename T>
T* data(PlaceType* place, int* size) const;
///
/// \brief Set lod info of the tensor.
/// More about LOD can be seen here:
/// https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor
/// \param x the lod info.
///
void SetLoD(const std::vector<std::vector<size_t>>& x);
/// \brief Return the lod info of the tensor.
std::vector<std::vector<size_t>> lod() const;
/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
/// \return The data type of the tensor.
DataType type() const;
/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
/// \brief Return the name of the tensor.
const std::string& name() const;
private:
std::unique_ptr<paddle::ZeroCopyTensor> tensor_;
};
///
/// \class Predictor
///
......@@ -258,31 +171,7 @@ PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);
PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
template <typename T>
void Tensor::CopyFromCpu(const T* data) {
tensor_->copy_from_cpu<T>(data);
}
template <typename T>
void Tensor::CopyToCpu(T* data) {
return tensor_->copy_to_cpu<T>(data);
}
template <typename T>
T* Tensor::mutable_data(PlaceType place) {
return tensor_->mutable_data<T>(place);
}
template <typename T>
T* Tensor::data(PlaceType* place, int* size) const {
return tensor_->data<T>(place, size);
}
} // namespace paddle_infer
namespace paddle_infer {
namespace services {
///
/// \class PredictorPool
///
......@@ -308,4 +197,5 @@ class PD_INFER_DECL PredictorPool {
std::vector<std::unique_ptr<Predictor>> preds_;
};
} // namespace services
} // namespace paddle_infer
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle_infer_declare.h" // NOLINT
namespace paddle_infer {
/// \brief Paddle data type.
enum DataType {
FLOAT32,
INT64,
INT32,
UINT8,
INT8,
// TODO(Superjomn) support more data types if needed.
};
enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU };
/// \brief Represents an n-dimensional array of values.
/// The Tensor is used to store the input or output of the network.
/// Zero copy means that the tensor supports direct copy of host or device data
/// to device,
/// eliminating additional CPU copy. Tensor is only used in the
/// AnalysisPredictor.
/// It is obtained through PaddlePredictor::GetinputTensor()
/// and PaddlePredictor::GetOutputTensor() interface.
class PD_INFER_DECL Tensor {
public:
/// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling mutable_data() or copy_from_cpu()
/// \param shape The shape to set.
void Reshape(const std::vector<int>& shape);
/// \brief Get the memory pointer in CPU or GPU with specific data type.
/// Please Reshape the tensor first before call this.
/// It's usually used to get input data pointer.
/// \param place The place of the tensor.
template <typename T>
T* mutable_data(PlaceType place);
/// \brief Get the memory pointer directly.
/// It's usually used to get the output data pointer.
/// \param[out] place To get the device type of the tensor.
/// \param[out] size To get the data size of the tensor.
/// \return The tensor data buffer pointer.
template <typename T>
T* data(PlaceType* place, int* size) const;
/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which the tensor will copy.
template <typename T>
void CopyFromCpu(const T* data);
/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to the address.
template <typename T>
void CopyToCpu(T* data);
/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
/// \brief Set lod info of the tensor.
/// More about LOD can be seen here:
/// https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor
/// \param x the lod info.
void SetLoD(const std::vector<std::vector<size_t>>& x);
/// \brief Return the lod info of the tensor.
std::vector<std::vector<size_t>> lod() const;
/// \brief Return the name of the tensor.
const std::string& name() const;
/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
/// \return The data type of the tensor.
DataType type() const;
protected:
explicit Tensor(void* scope);
void* FindTensor() const;
void SetPlace(PlaceType place, int device = -1);
void SetName(const std::string& name);
std::string name_;
// The corresponding tensor pointer inside Paddle workspace is cached for
// performance.
mutable void* tensor_{nullptr};
DataType dtype_;
bool input_or_output_;
void* scope_{nullptr};
PlaceType place_;
int device_;
};
} // namespace paddle_infer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册