// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_tensor.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/core/allocator.h" #ifdef PADDLE_WITH_ONNXRUNTIME #include "onnxruntime_c_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT #endif namespace paddle_infer { using float16 = paddle::platform::float16; void Tensor::Reshape(const std::vector &shape) { #ifdef PADDLE_WITH_ONNXRUNTIME if (is_ort_tensor_) { shape_.assign(shape.begin(), shape.end()); return; } #endif PADDLE_ENFORCE_EQ( name_.empty(), false, paddle::platform::errors::PreconditionNotMet( "Need to SetName first, so that the corresponding tensor can " "be retrieved.")); PADDLE_ENFORCE_EQ(input_or_output_, true, paddle::platform::errors::PermissionDenied( "Can't reshape the output tensor, it is readonly")); auto *scope = static_cast(scope_); auto *var = scope->FindVar(name_); PADDLE_ENFORCE_NOT_NULL( var, paddle::platform::errors::PreconditionNotMet( "No tensor called [%s] in the runtime scope", name_)); auto *tensor = var->GetMutable(); tensor->Resize(phi::make_ddim(shape)); } void Tensor::ReshapeStrings(const size_t &shape) { PADDLE_ENFORCE_EQ( name_.empty(), false, paddle::platform::errors::PreconditionNotMet( "Need to SetName first, so that the corresponding tensor can " "be retrieved.")); PADDLE_ENFORCE_EQ(input_or_output_, true, paddle::platform::errors::PermissionDenied( "Can't reshape the output tensor, it is readonly")); auto *scope = static_cast(scope_); auto *var = scope->FindVar(name_); PADDLE_ENFORCE_NOT_NULL( var, paddle::platform::errors::PreconditionNotMet( "No tensor called [%s] in the runtime scope", name_)); paddle_infer::Strings *tensor = var->GetMutable(); tensor->resize(shape); } #define EAGER_GET_TENSOR(tensor_type) \ if (!tensor_) { \ tensor_ = FindTensor(); \ } \ auto *tensor = static_cast(tensor_); template T *Tensor::mutable_data(PlaceType place) { #ifdef PADDLE_WITH_ONNXRUNTIME if (is_ort_tensor_) { return ORTGetMutableData(); } #endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_GT( tensor->numel(), 0, paddle::platform::errors::PreconditionNotMet( "You should call Tensor::Reshape(const std::vector " "&shape)" "function before retrieving mutable_data from input tensor.")); switch (static_cast(place)) { case static_cast(PlaceType::kCPU): { return tensor->mutable_data(paddle::platform::CPUPlace()); } case static_cast(PlaceType::kGPU): { return tensor->mutable_data(paddle::platform::CUDAPlace(device_)); } case static_cast(PlaceType::kXPU): { return tensor->mutable_data(paddle::platform::XPUPlace(device_)); } case static_cast(PlaceType::kNPU): { return tensor->mutable_data(paddle::platform::NPUPlace(device_)); } default: PADDLE_THROW(paddle::platform::errors::Unavailable( "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is " "not supported.", static_cast(place))); break; } return nullptr; } template T *Tensor::data(PlaceType *place, int *size) const { EAGER_GET_TENSOR(paddle::framework::LoDTensor); auto *res = tensor->data(); if (paddle::platform::is_cpu_place(tensor->place())) { *place = PlaceType::kCPU; } else if (paddle::platform::is_gpu_place(tensor->place())) { *place = PlaceType::kGPU; } else if (paddle::platform::is_xpu_place(tensor->place())) { *place = PlaceType::kXPU; } else if (paddle::platform::is_npu_place(tensor->place())) { *place = PlaceType::kNPU; } else { *place = PlaceType::kUNK; } *size = tensor->numel(); return res; } DataType Tensor::type() const { #ifdef PADDLE_WITH_ONNXRUNTIME if (is_ort_tensor_) { return dtype_; } #endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); auto type = paddle::framework::TransToProtoVarType(tensor->dtype()); if (type == paddle::framework::proto::VarType::FP32) { return DataType::FLOAT32; } else if (type == paddle::framework::proto::VarType::FP16) { return DataType::FLOAT16; } else if (type == paddle::framework::proto::VarType::INT64) { return DataType::INT64; } else if (type == paddle::framework::proto::VarType::INT32) { return DataType::INT32; } else if (type == paddle::framework::proto::VarType::UINT8) { return DataType::UINT8; } else if (type == paddle::framework::proto::VarType::INT8) { return DataType::INT8; } return DataType::FLOAT32; } PlaceType Tensor::place() const { return place_; } template void Tensor::CopyFromCpu(const T *data) { EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_GE(tensor->numel(), 0, paddle::platform::errors::PreconditionNotMet( "You should call Tensor::Reshape(const " "std::vector &shape)" "function before copying data from cpu.")); size_t ele_size = tensor->numel() * sizeof(T); if (place_ == PlaceType::kCPU) { auto *t_data = tensor->mutable_data(paddle::platform::CPUPlace()); std::memcpy(static_cast(t_data), data, ele_size); } else if (place_ == PlaceType::kGPU) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); paddle::platform::CUDAPlace gpu_place(device_); auto *t_data = tensor->mutable_data(gpu_place); auto *dev_ctx = static_cast( pool.Get(gpu_place)); paddle::memory::Copy(gpu_place, static_cast(t_data), paddle::platform::CPUPlace(), data, ele_size, dev_ctx->stream()); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with CUDA place because paddle is not compiled " "with CUDA.")); #endif } else if (place_ == PlaceType::kXPU) { #ifdef PADDLE_WITH_XPU paddle::platform::XPUPlace xpu_place(device_); auto *t_data = tensor->mutable_data(xpu_place); paddle::memory::Copy(xpu_place, static_cast(t_data), paddle::platform::CPUPlace(), data, ele_size); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with XPU place because paddle is not compiled " "with XPU.")); #endif } else if (place_ == PlaceType::kNPU) { #ifdef PADDLE_WITH_ASCEND_CL paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); paddle::platform::NPUPlace npu_place(device_); auto *t_data = tensor->mutable_data(npu_place); auto *dev_ctx = static_cast( pool.Get(npu_place)); paddle::memory::Copy(npu_place, static_cast(t_data), paddle::platform::CPUPlace(), data, ele_size, dev_ctx->stream()); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with NPU place because paddle is not compiled " "with NPU.")); #endif } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The analysis predictor supports CPU, GPU, NPU and XPU now.")); } } template struct DataTypeInfo; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32; }; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16; }; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64; }; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8; }; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8; }; template <> struct DataTypeInfo { paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32; }; paddle::experimental::DataLayout LayoutConvert(DataLayout layout) { PADDLE_ENFORCE_EQ( layout, DataLayout::kNCHW, paddle::platform::errors::InvalidArgument("Only NCHW is supported now.")); return paddle::experimental::DataLayout::NCHW; } template void Tensor::ShareExternalData(const T *data, const std::vector &shape, PlaceType place, DataLayout layout) { EAGER_GET_TENSOR(paddle::framework::LoDTensor) size_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * sizeof(T); phi::DenseTensorMeta meta( DataTypeInfo().TYPE, phi::make_ddim(shape), LayoutConvert(layout)); if (place == PlaceType::kCPU) { phi::DenseTensor dtensor( std::make_shared( const_cast(data), size, paddle::platform::CPUPlace()), meta); *tensor = std::move(dtensor); } else if (place == PlaceType::kGPU) { phi::DenseTensor dtensor( std::make_shared( const_cast(data), size, paddle::platform::CUDAPlace(device_)), meta); *tensor = std::move(dtensor); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "PlaceType must be PlaceType::kCPU or PlaceType::kGPU.")); } } void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) { EAGER_GET_TENSOR(paddle_infer::Strings); PADDLE_ENFORCE_GE(tensor->size(), 0, paddle::platform::errors::PreconditionNotMet( "You should call Tensor::Reshape(const " "std::size_t &shape)function before copying" "the string data from cpu.")); *tensor = *data; } template void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, void *cb_params) const { EAGER_GET_TENSOR(paddle::framework::LoDTensor); auto ele_num = tensor->numel(); auto *t_data = tensor->data(); auto t_place = tensor->place(); paddle::framework::Tensor out; auto mem_allocation = std::make_shared( static_cast(data), ele_num * sizeof(T), paddle::platform::CPUPlace()); out.ResetHolder(mem_allocation); if (paddle::platform::is_cpu_place(t_place)) { #ifdef PADDLE_WITH_MKLDNN if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) paddle::framework::innerTransDataLayoutFromMKLDNN( tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls() .get_cur_paddle_data_layout(), *tensor, &out, paddle::platform::CPUPlace(), true); else std::memcpy(static_cast(data), t_data, ele_num * sizeof(T)); #else std::memcpy(static_cast(data), t_data, ele_num * sizeof(T)); #endif } else if (paddle::platform::is_ipu_place(t_place)) { #ifdef PADDLE_WITH_IPU std::memcpy(static_cast(data), t_data, ele_num * sizeof(T)); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with IPU place because paddle is not compiled " "with IPU.")); #endif } else if (place_ == PlaceType::kGPU) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); auto gpu_place = t_place; auto *dev_ctx = static_cast( pool.Get(gpu_place)); paddle::memory::Copy(paddle::platform::CPUPlace(), static_cast(data), gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream()); #ifdef PADDLE_WITH_HIP hipStreamSynchronize(dev_ctx->stream()); #else // async, return stream if (nullptr != exec_stream) { *(static_cast(exec_stream)) = dev_ctx->stream(); // async with callback } else if (cb) { cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params); // sync } else { cudaStreamSynchronize(dev_ctx->stream()); } #endif #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with CUDA place because paddle is not compiled " "with CUDA.")); #endif } else if (place_ == PlaceType::kXPU) { #ifdef PADDLE_WITH_XPU auto xpu_place = t_place; paddle::memory::Copy(paddle::platform::CPUPlace(), static_cast(data), xpu_place, t_data, ele_num * sizeof(T)); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with XPU place because paddle is not compiled " "with XPU.")); #endif } else if (place_ == PlaceType::kNPU) { #ifdef PADDLE_WITH_ASCEND_CL paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); auto npu_place = t_place; auto *dev_ctx = static_cast( pool.Get(npu_place)); paddle::memory::Copy(paddle::platform::CPUPlace(), static_cast(data), npu_place, t_data, ele_num * sizeof(T), dev_ctx->stream()); paddle::platform::NPUStreamSync(dev_ctx->stream()); #else PADDLE_THROW(paddle::platform::errors::Unavailable( "Can not create tensor with NPU place because paddle is not compiled " "with NPU.")); #endif } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The analysis predictor supports CPU, GPU, NPU and XPU now.")); } } template void Tensor::CopyToCpu(T *data) const { #ifdef PADDLE_WITH_ONNXRUNTIME if (is_ort_tensor_) { ORTCopyToCpu(data); return; } #endif CopyToCpuImpl(data, nullptr, nullptr, nullptr); } template void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const { CopyToCpuImpl(data, exec_stream, nullptr, nullptr); } template void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const { CopyToCpuImpl(data, nullptr, cb, cb_params); } template PD_INFER_DECL void Tensor::CopyFromCpu(const float *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const int64_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const int32_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const uint8_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const int8_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const float16 *data); template PD_INFER_DECL void Tensor::ShareExternalData( const float *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const int64_t *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const int32_t *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const uint8_t *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const int8_t *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const float16 *data, const std::vector &shape, PlaceType place, DataLayout layout); template PD_INFER_DECL void Tensor::CopyToCpu(float *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(int64_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(int32_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(uint8_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(int8_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(float16 *data) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl(float *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int64_t *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int32_t *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( uint8_t *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int8_t *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float16 *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int64_t *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int32_t *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( uint8_t *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( int8_t *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float16 *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL float *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL int64_t *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL int32_t *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL uint8_t *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL int8_t *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL float16 *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL float *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL int64_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL int32_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL uint8_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL int8_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL float16 *Tensor::mutable_data(PlaceType place); Tensor::Tensor(void *scope) : scope_{scope} {} template void *Tensor::FindTensor() const { PADDLE_ENFORCE_EQ( name_.empty(), false, paddle::platform::errors::PreconditionNotMet( "Need to SetName first, so that the corresponding tensor can " "be retrieved.")); auto *scope = static_cast(scope_); auto *var = scope->FindVar(name_); PADDLE_ENFORCE_NOT_NULL( var, paddle::platform::errors::PreconditionNotMet( "No tensor called [%s] in the runtime scope", name_)); auto *tensor = var->GetMutable(); return tensor; } std::vector Tensor::shape() const { #ifdef PADDLE_WITH_ONNXRUNTIME if (is_ort_tensor_) { std::vector shape; // input handle if (idx_ < 0) { shape.assign(shape_.begin(), shape_.end()); } else { // output handle auto binding = binding_.lock(); PADDLE_ENFORCE_NOT_NULL(binding, paddle::platform::errors::PreconditionNotMet( "output tensor [%s] no binding ptr", name_)); std::vector outputs = binding->GetOutputValues(); Ort::Value &value = outputs[idx_]; auto info = value.GetTensorTypeAndShapeInfo(); auto ort_shape = info.GetShape(); shape.assign(ort_shape.begin(), ort_shape.end()); } return shape; } #endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_NOT_NULL( tensor_, paddle::platform::errors::PreconditionNotMet( "Not found tensor called %s in the scope", name_)); // mkldnn may does layout transform internally, so need to reorder before // return #ifdef PADDLE_WITH_MKLDNN if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) { paddle::framework::DataLayout out_layout = paddle::platform::MKLDNNDeviceContext::tls() .get_cur_paddle_data_layout(); // Set default as NCHW in case not specified out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout ? paddle::framework::DataLayout::kNCHW : out_layout; // In these data layouts, channel dimension is either on 2nd position: nChw // or // at last nhwC, so for dim==2 these layouts are the same and nothing should // be done. Similarly for dim==1 when you have just one possible // combination. if (tensor->dims().size() < 3) return phi::vectorize(tensor->dims()); if (out_layout == paddle::framework::DataLayout::kNHWC || out_layout == paddle::framework::DataLayout::kNDHWC) { auto dims = phi::vectorize(tensor->dims()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); return dims; } else { return phi::vectorize(tensor->dims()); } } #endif return phi::vectorize(tensor->dims()); } void Tensor::SetLoD(const std::vector> &x) { EAGER_GET_TENSOR(paddle::framework::LoDTensor); paddle::framework::LoD lod; for (auto &level : x) { lod.emplace_back(level); } tensor->set_lod(lod); } std::vector> Tensor::lod() const { EAGER_GET_TENSOR(paddle::framework::LoDTensor); std::vector> res; for (auto &level : tensor->lod()) { res.emplace_back(level); } return res; } void Tensor::SetName(const std::string &name) { name_ = name; } const std::string &Tensor::name() const { return name_; } void Tensor::SetPlace(PlaceType place, int device) { place_ = place; device_ = device; } #ifdef PADDLE_WITH_ONNXRUNTIME void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; } void Tensor::SetOrtBinding(const std::shared_ptr binding) { binding_ = binding; } template T *Tensor::ORTGetMutableData() { auto binding = binding_.lock(); PADDLE_ENFORCE_NOT_NULL(binding, paddle::platform::errors::PreconditionNotMet( "output tensor [%s] no binding ptr", name_)); std::vector outputs = binding->GetOutputValues(); Ort::Value &value = outputs[idx_]; return value.GetTensorMutableData(); } template void Tensor::ORTCopyToCpu(T *data) const { auto binding = binding_.lock(); PADDLE_ENFORCE_NOT_NULL(binding, paddle::platform::errors::PreconditionNotMet( "output tensor [%s] no binding ptr", name_)); std::vector outputs = binding->GetOutputValues(); Ort::Value &value = outputs[idx_]; auto info = value.GetTensorTypeAndShapeInfo(); size_t size = info.GetElementCount() * sizeof(T); if (place_ == PlaceType::kCPU) { std::memcpy(static_cast(data), value.GetTensorData(), size); } else { PADDLE_THROW(paddle::platform::errors::Unavailable( "CopyToCpu error.The current ONNXRuntime backend doesn't support " "GPU.")); } } template void Tensor::ORTCopyToCpu(float *data) const; template void Tensor::ORTCopyToCpu(int32_t *data) const; template void Tensor::ORTCopyToCpu(uint8_t *data) const; template void Tensor::ORTCopyToCpu(int8_t *data) const; template void Tensor::ORTCopyToCpu(float16 *data) const; #endif } // namespace paddle_infer