未验证 提交 10fd4a95 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle Inference] Predictor support paddle::Tensor (#50445)

上级 259b0aad
...@@ -17,8 +17,6 @@ set(PADDLE_INFERENCE_INSTALL_DIR ...@@ -17,8 +17,6 @@ set(PADDLE_INFERENCE_INSTALL_DIR
function(phi_header_path_compat TARGET_PATH) function(phi_header_path_compat TARGET_PATH)
message(STATUS "phi header path compat processing: ${TARGET_PATH}") message(STATUS "phi header path compat processing: ${TARGET_PATH}")
string(FIND ${TARGET_PATH} "experimental" pos)
if(pos GREATER 1)
file(GLOB HEADERS "${TARGET_PATH}/*" "*.h") file(GLOB HEADERS "${TARGET_PATH}/*" "*.h")
foreach(header ${HEADERS}) foreach(header ${HEADERS})
if(${header} MATCHES ".*.h$") if(${header} MATCHES ".*.h$")
...@@ -34,7 +32,6 @@ function(phi_header_path_compat TARGET_PATH) ...@@ -34,7 +32,6 @@ function(phi_header_path_compat TARGET_PATH)
message(STATUS "phi header path compat processing complete: ${header}") message(STATUS "phi header path compat processing complete: ${header}")
endif() endif()
endforeach() endforeach()
endif()
endfunction() endfunction()
phi_header_path_compat( phi_header_path_compat(
...@@ -51,6 +48,7 @@ phi_header_path_compat( ...@@ -51,6 +48,7 @@ phi_header_path_compat(
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/common) ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/common)
phi_header_path_compat( phi_header_path_compat(
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/core) ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/core)
phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/)
# In order to be compatible with the original behavior, the header file name needs to be changed # In order to be compatible with the original behavior, the header file name needs to be changed
file(RENAME file(RENAME
......
...@@ -95,7 +95,7 @@ phi::DenseTensor& GetVariableTensor(const Scope& scope, ...@@ -95,7 +95,7 @@ phi::DenseTensor& GetVariableTensor(const Scope& scope,
PADDLE_ENFORCE_EQ(var->IsType<phi::DenseTensor>(), PADDLE_ENFORCE_EQ(var->IsType<phi::DenseTensor>(),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Only support lod tensor in GetVariableTensor now.")); "Only support DenseTensor in GetVariableTensor now."));
return *var->GetMutable<phi::DenseTensor>(); return *var->GetMutable<phi::DenseTensor>();
} }
......
...@@ -155,9 +155,8 @@ phi::Backend ConvertBackend(paddle_infer::PlaceType backend) { ...@@ -155,9 +155,8 @@ phi::Backend ConvertBackend(paddle_infer::PlaceType backend) {
return phi::Backend::CPU; return phi::Backend::CPU;
} }
} }
} // namespace
bool PaddleTensorToLoDTensor(const PaddleTensor &pt, bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
phi::DenseTensor *t, phi::DenseTensor *t,
const platform::Place &place) { const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape); framework::DDim ddim = phi::make_ddim(pt.shape);
...@@ -270,6 +269,7 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, ...@@ -270,6 +269,7 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
t->set_lod(lod); t->set_lod(lod);
return true; return true;
} }
} // namespace
bool AnalysisPredictor::Init( bool AnalysisPredictor::Init(
const std::shared_ptr<framework::Scope> &parent_scope, const std::shared_ptr<framework::Scope> &parent_scope,
...@@ -919,6 +919,17 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) { ...@@ -919,6 +919,17 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
#endif #endif
} }
void AnalysisPredictor::MkldnnPreSet(
const std::vector<paddle::Tensor> &inputs) {
#ifdef PADDLE_WITH_MKLDNN
std::vector<std::vector<int>> inputs_shape;
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_shape.emplace_back(phi::vectorize<int>(inputs[i].dims()));
}
MkldnnPreSet(inputs_shape);
#endif
}
void AnalysisPredictor::MkldnnPreSet( void AnalysisPredictor::MkldnnPreSet(
const std::vector<std::vector<int>> &inputs_shape) { const std::vector<std::vector<int>> &inputs_shape) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -1033,6 +1044,70 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -1033,6 +1044,70 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
return true; return true;
} }
bool AnalysisPredictor::Run(const std::vector<paddle::Tensor> &inputs,
std::vector<paddle::Tensor> *outputs) {
inference::DisplayMemoryInfo(place_, "before run");
paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
#ifdef PADDLE_WITH_MKLDNN
if (config_.use_mkldnn_) MkldnnPreSet(inputs);
#endif
VLOG(3) << "predict start";
// set feed variable
framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::PreconditionNotMet("The scope should not be nullptr."));
if (!SetFeed(inputs, scope)) {
LOG(ERROR) << "fail to set feed";
return false;
}
#ifdef PADDLE_WITH_TENSORRT
if (config_.tensorrt_engine_enabled()) {
inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
predictor_id_;
VLOG(3) << "thread_local var predictor_id in TensorRTEngine is set to: "
<< inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
}
#endif
// Run the inference program
// if share variables, we need not create variables
executor_->Run();
inference::DisplayMemoryInfo(place_, "after run");
// get fetch variable
if (!GetFetch(outputs, scope)) {
LOG(ERROR) << "fail to get fetches";
return false;
}
// All the containers in the scope will be hold in inference, but the
// operators assume that the container will be reset after each batch.
// Here is a bugfix, collect all the container variables, and reset then to a
// bool; the next time, the operator will call MutableData and construct a new
// container again, so that the container will be empty for each batch.
if (sub_scope_) {
tensor_array_batch_cleaner_.CollectNoTensorVars(sub_scope_);
}
tensor_array_batch_cleaner_.ResetNoTensorVars();
// recover the cpu_math_library_num_threads to 1, in order to avoid thread
// conflict when integrating it into deployment service.
paddle::platform::SetNumThreads(1);
#ifdef PADDLE_WITH_MKLDNN
if (config_.use_mkldnn_) MkldnnPostReset();
#endif
#if defined(PADDLE_WITH_MKLML)
// Frees unused memory allocated by the Intel® MKL Memory Allocator to
// avoid memory leak. See:
// https://software.intel.com/en-us/mkl-developer-reference-c-mkl-free-buffers
platform::dynload::MKL_Free_Buffers();
#endif
return true;
}
bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework::Scope *scope) { framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed"; VLOG(3) << "Predictor::set_feed";
...@@ -1047,7 +1122,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -1047,7 +1122,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
phi::DenseTensor *input = &feed_tensors_[i]; phi::DenseTensor *input = &feed_tensors_[i];
if (!PaddleTensorToLoDTensor(inputs[i], input, place_)) { if (!PaddleTensorToDenseTensor(inputs[i], input, place_)) {
return false; return false;
} }
int idx = -1; int idx = -1;
...@@ -1061,7 +1136,41 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -1061,7 +1136,41 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
} else { } else {
idx = PADDLE_GET_CONST(int, feeds_[i]->GetAttr("col")); idx = PADDLE_GET_CONST(int, feeds_[i]->GetAttr("col"));
} }
framework::SetFeedVariable(scope, *input, "feed", idx); framework::SetFeedVariable(scope, *input, framework::kFeedOpType, idx);
}
return true;
}
bool AnalysisPredictor::SetFeed(const std::vector<paddle::Tensor> &inputs,
framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed";
PADDLE_ENFORCE_EQ(inputs.size(),
feeds_.size(),
platform::errors::InvalidArgument(
"wrong feed input size, need %d but get %d.",
feeds_.size(),
inputs.size()));
for (size_t i = 0; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i].initialized(),
true,
paddle::platform::errors::InvalidArgument(
"The input Tensor expected to be initialized."));
}
if (std::all_of(inputs.cbegin(), inputs.cend(), [&](const paddle::Tensor &t) {
return !t.name().empty() && feed_names_.count(t.name());
})) {
for (size_t i = 0; i < inputs.size(); ++i) {
auto &t = framework::GetVariableTensor(*scope, inputs[i].name());
t.ShareDataWith(
*std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl()));
}
} else {
for (size_t i = 0; i < inputs.size(); ++i) {
auto &t = framework::GetVariableTensor(*scope, idx2feeds_[i]);
t.ShareDataWith(
*std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl()));
}
} }
return true; return true;
} }
...@@ -1100,7 +1209,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -1100,7 +1209,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
idx, idx,
i)); i));
framework::FetchType &fetch_var = framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx); framework::GetFetchVariable(*scope, framework::kFetchOpType, idx);
auto &fetch = PADDLE_GET(phi::DenseTensor, fetch_var); auto &fetch = PADDLE_GET(phi::DenseTensor, fetch_var);
auto type = framework::TransToProtoVarType(fetch.dtype()); auto type = framework::TransToProtoVarType(fetch.dtype());
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
...@@ -1125,6 +1234,19 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -1125,6 +1234,19 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
return true; return true;
} }
bool AnalysisPredictor::GetFetch(std::vector<paddle::Tensor> *outputs,
framework::Scope *scope) {
VLOG(3) << "Predictor::get_fetch";
outputs->resize(fetches_.size());
for (size_t i = 0; i < fetches_.size(); ++i) {
auto const &name = idx2fetches_[i];
auto &t = framework::GetVariableTensor(*scope, name);
(*outputs)[i] =
std::move(paddle::Tensor(std::make_shared<phi::DenseTensor>(t), name));
}
return true;
}
void AnalysisPredictor::PrepareArgument() { void AnalysisPredictor::PrepareArgument() {
VLOG(3) << "AnalysisPredictor::PrepareArgument"; VLOG(3) << "AnalysisPredictor::PrepareArgument";
// Init std::unique_ptr argument_. // Init std::unique_ptr argument_.
...@@ -1579,7 +1701,7 @@ void AnalysisPredictor::PrepareFeedFetch() { ...@@ -1579,7 +1701,7 @@ void AnalysisPredictor::PrepareFeedFetch() {
"The sub_scope should not be nullptr.")); "The sub_scope should not be nullptr."));
CreateFeedFetchVar(sub_scope_); CreateFeedFetchVar(sub_scope_);
for (auto *op : inference_program_->Block(0).AllOps()) { for (auto *op : inference_program_->Block(0).AllOps()) {
if (op->Type() == "feed") { if (op->Type() == framework::kFeedOpType) {
int idx = PADDLE_GET_CONST(int, op->GetAttr("col")); int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
if (feeds_.size() <= static_cast<size_t>(idx)) { if (feeds_.size() <= static_cast<size_t>(idx)) {
feeds_.resize(idx + 1); feeds_.resize(idx + 1);
...@@ -1587,7 +1709,7 @@ void AnalysisPredictor::PrepareFeedFetch() { ...@@ -1587,7 +1709,7 @@ void AnalysisPredictor::PrepareFeedFetch() {
feeds_[idx] = op; feeds_[idx] = op;
feed_names_[op->Output("Out")[0]] = idx; feed_names_[op->Output("Out")[0]] = idx;
idx2feeds_[idx] = op->Output("Out")[0]; idx2feeds_[idx] = op->Output("Out")[0];
} else if (op->Type() == "fetch") { } else if (op->Type() == framework::kFetchOpType) {
int idx = PADDLE_GET_CONST(int, op->GetAttr("col")); int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
if (fetches_.size() <= static_cast<size_t>(idx)) { if (fetches_.size() <= static_cast<size_t>(idx)) {
fetches_.resize(idx + 1); fetches_.resize(idx + 1);
...@@ -1602,9 +1724,9 @@ void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) { ...@@ -1602,9 +1724,9 @@ void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, scope,
platform::errors::InvalidArgument("The scope should not be nullptr.")); platform::errors::InvalidArgument("The scope should not be nullptr."));
auto *var = scope->Var("feed"); auto *var = scope->Var(framework::kFeedOpType);
var->GetMutable<framework::FeedList>(); var->GetMutable<framework::FeedList>();
var = scope->Var("fetch"); var = scope->Var(framework::kFetchOpType);
var->GetMutable<framework::FetchList>(); var->GetMutable<framework::FetchList>();
} }
...@@ -2186,7 +2308,7 @@ void AnalysisPredictor::ClearIntermediateTensor() { ...@@ -2186,7 +2308,7 @@ void AnalysisPredictor::ClearIntermediateTensor() {
const std::string name = var->Name(); const std::string name = var->Name();
auto *variable = executor_->GetScope()->FindVar(name); auto *variable = executor_->GetScope()->FindVar(name);
if (variable != nullptr && variable->IsType<phi::DenseTensor>() && if (variable != nullptr && variable->IsType<phi::DenseTensor>() &&
name != "feed" && name != "fetch") { name != framework::kFeedOpType && name != framework::kFetchOpType) {
VLOG(3) << "Clear Intermediate Tensor: " << name; VLOG(3) << "Clear Intermediate Tensor: " << name;
auto *t = variable->GetMutable<phi::DenseTensor>(); auto *t = variable->GetMutable<phi::DenseTensor>();
t->clear(); t->clear();
...@@ -2653,6 +2775,11 @@ std::map<std::string, DataType> Predictor::GetOutputTypes() { ...@@ -2653,6 +2775,11 @@ std::map<std::string, DataType> Predictor::GetOutputTypes() {
bool Predictor::Run() { return predictor_->ZeroCopyRun(); } bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
bool Predictor::Run(const std::vector<paddle::Tensor> &inputs,
std::vector<paddle::Tensor> *outputs) {
return predictor_->Run(inputs, outputs);
}
std::unique_ptr<Predictor> Predictor::Clone(void *stream) { std::unique_ptr<Predictor> Predictor::Clone(void *stream) {
auto analysis_pred = predictor_->Clone(stream); auto analysis_pred = predictor_->Clone(stream);
std::unique_ptr<Predictor> pred(new Predictor(std::move(analysis_pred))); std::unique_ptr<Predictor> pred(new Predictor(std::move(analysis_pred)));
......
...@@ -31,15 +31,14 @@ ...@@ -31,15 +31,14 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/inference/api/resource_manager.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <gtest/gtest_prod.h> #include <gtest/gtest_prod.h>
#endif #endif
namespace paddle_infer { namespace paddle_infer {
using float16 = paddle::platform::float16;
namespace experimental { namespace experimental {
class InternalUtils; class InternalUtils;
}; };
...@@ -150,6 +149,16 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -150,6 +149,16 @@ class AnalysisPredictor : public PaddlePredictor {
std::vector<PaddleTensor> *output_data, std::vector<PaddleTensor> *output_data,
int batch_size = -1) override; int batch_size = -1) override;
///
/// \brief Run the prediction engine (Recommended).
///
/// \param[in] inputs input tensors
/// \param[out] outputs output tensors
/// \return Whether the function executed successfully
///
bool Run(const std::vector<paddle::Tensor> &inputs,
std::vector<paddle::Tensor> *outputs) override;
/// ///
/// \brief Get the input names /// \brief Get the input names
/// ///
...@@ -378,6 +387,17 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -378,6 +387,17 @@ class AnalysisPredictor : public PaddlePredictor {
/// ///
bool SetFeed(const std::vector<PaddleTensor> &input_datas, bool SetFeed(const std::vector<PaddleTensor> &input_datas,
framework::Scope *scope); framework::Scope *scope);
///
/// \brief Prepare input data, only used in Run()
///
/// \param[in] inputs inpute tensors
/// \param[in] scope the scope used by predictor
/// \return Whether the function executed successfully
///
bool SetFeed(const std::vector<paddle::Tensor> &inputs,
framework::Scope *scope);
/// ///
/// \brief Get the output data, only used in Run() /// \brief Get the output data, only used in Run()
/// ///
...@@ -387,6 +407,16 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -387,6 +407,16 @@ class AnalysisPredictor : public PaddlePredictor {
/// ///
bool GetFetch(std::vector<PaddleTensor> *output_data, bool GetFetch(std::vector<PaddleTensor> *output_data,
framework::Scope *scope); framework::Scope *scope);
///
/// \brief Get the output data, only used in Run()
///
/// \param[out] outputs output tensors
/// \param[in] scope the scope used by predictor
/// \return Whether the function executed successfully
///
bool GetFetch(std::vector<paddle::Tensor> *outputs, framework::Scope *scope);
/// ///
/// \brief Get the output data, only used in GetFetch() /// \brief Get the output data, only used in GetFetch()
/// ///
...@@ -404,6 +434,14 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -404,6 +434,14 @@ class AnalysisPredictor : public PaddlePredictor {
/// \param[in] inputs tensors /// \param[in] inputs tensors
/// ///
void MkldnnPreSet(const std::vector<PaddleTensor> &inputs); void MkldnnPreSet(const std::vector<PaddleTensor> &inputs);
///
/// \brief PreSet for Mkldnn multi-thread and dynamic shape input.
///
/// Used in AnalysisPredictor::Run().
///
/// \param[in] inputs tensors
///
void MkldnnPreSet(const std::vector<paddle::Tensor> &inputs);
/// ///
/// \brief PreSet for Mkldnn multi-thread and dynamic shape input. /// \brief PreSet for Mkldnn multi-thread and dynamic shape input.
......
...@@ -83,7 +83,7 @@ else() ...@@ -83,7 +83,7 @@ else()
if(WITH_MKL) if(WITH_MKL)
set(FLAG_OPENMP "-fopenmp") set(FLAG_OPENMP "-fopenmp")
endif() endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 ${FLAG_OPENMP}")
endif() endif()
if(WITH_GPU) if(WITH_GPU)
......
...@@ -221,6 +221,16 @@ class PD_INFER_DECL PaddlePredictor { ...@@ -221,6 +221,16 @@ class PD_INFER_DECL PaddlePredictor {
std::vector<PaddleTensor>* output_data, std::vector<PaddleTensor>* output_data,
int batch_size = -1) = 0; int batch_size = -1) = 0;
/// \brief This interface takes input and runs the network (Recommended).
/// \param[in] inputs An list of Tensor as the input to the network.
/// \param[out] output_data Pointer to the tensor list, which holds the output
/// Tensor
/// \return Whether the run is successful
virtual bool Run(const std::vector<paddle::Tensor>& inputs,
std::vector<paddle::Tensor>* outputs) {
return false;
}
/// \brief Used to get the name of the network input. /// \brief Used to get the name of the network input.
/// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. /// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
/// \return Input tensor names. /// \return Input tensor names.
......
...@@ -128,6 +128,17 @@ class PD_INFER_DECL Predictor { ...@@ -128,6 +128,17 @@ class PD_INFER_DECL Predictor {
/// ///
bool Run(); bool Run();
///
/// \brief Run the prediction engine (Recommended)
///
/// \param[in] inputs An list of Tensor as the input to the network.
/// \param[out] outputs Pointer to the tensor list, which holds the output
/// Tensor
///
/// \return Whether the run is successful
bool Run(const std::vector<paddle::Tensor>& inputs,
std::vector<paddle::Tensor>* outputs);
/// ///
/// \brief Get the output names /// \brief Get the output names
/// ///
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle_infer_declare.h" // NOLINT #include "paddle_infer_declare.h" // NOLINT
#include "paddle/phi/api/include/tensor.h" // expose paddle::Tensor
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h" // NOLINT #include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
......
...@@ -22,11 +22,6 @@ ...@@ -22,11 +22,6 @@
namespace paddle { namespace paddle {
namespace jit { namespace jit {
static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t);
static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
DenseTensor *t,
const platform::Place &place);
PredictorEngine::PredictorEngine( PredictorEngine::PredictorEngine(
const std::shared_ptr<FunctionInfo> &info, const std::shared_ptr<FunctionInfo> &info,
const std::shared_ptr<VariableMap> &params_dict, const std::shared_ptr<VariableMap> &params_dict,
...@@ -52,6 +47,7 @@ PredictorEngine::PredictorEngine( ...@@ -52,6 +47,7 @@ PredictorEngine::PredictorEngine(
config.SetSkipLoadParams(true); config.SetSkipLoadParams(true);
config.SetApplyOptim(true); config.SetApplyOptim(true);
config.SwitchIrOptim(true); config.SwitchIrOptim(true);
config.SwitchUseFeedFetchOps(false);
predictor_.reset(new AnalysisPredictor(config)); predictor_.reset(new AnalysisPredictor(config));
...@@ -78,135 +74,15 @@ std::unique_ptr<BaseEngine> PredictorEngine::Clone(void *stream) { ...@@ -78,135 +74,15 @@ std::unique_ptr<BaseEngine> PredictorEngine::Clone(void *stream) {
std::vector<Tensor> PredictorEngine::operator()( std::vector<Tensor> PredictorEngine::operator()(
const std::vector<Tensor> &inputs) { const std::vector<Tensor> &inputs) {
auto dense_tensors = utils::ToDenseTensors(inputs); std::vector<Tensor> outputs;
return utils::ToTensors(this->operator()(dense_tensors)); predictor_->Run(inputs, &outputs);
}
std::vector<DenseTensor> PredictorEngine::operator()(
const std::vector<DenseTensor> &inputs) {
std::vector<PaddleTensor> pt_inputs;
std::vector<PaddleTensor> pt_outputs;
for (auto &t : inputs) {
auto non_const_t = const_cast<DenseTensor *>(&t);
pt_inputs.emplace_back(DenseTensorToPaddleTensor(non_const_t));
}
predictor_->Run(pt_inputs, &pt_outputs);
std::vector<DenseTensor> outputs;
for (auto &pt : pt_outputs) {
DenseTensor t;
PaddleTensorToDenseTensor(pt, &t, place_);
outputs.emplace_back(t);
}
return outputs; return outputs;
} }
static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) { std::vector<DenseTensor> PredictorEngine::operator()(
PaddleTensor pt; const std::vector<DenseTensor> &inputs) {
switch (framework::TransToProtoVarType(t->dtype())) { return utils::ToDenseTensors(this->operator()(utils::ToTensors(inputs)));
case framework::proto::VarType::INT32: {
pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} break;
case framework::proto::VarType::INT64: {
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} break;
case framework::proto::VarType::FP32: {
pt.data.Reset(t->data(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported tensor date type. Now "
"only supports INT64, FP32, INT32."));
}
pt.shape = phi::vectorize<int>(t->dims());
return pt;
}
static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
DenseTensor *t,
const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape);
void *input_ptr;
switch (pt.dtype) {
case PaddleDType::INT64:
input_ptr = t->mutable_data<int64_t>(ddim, place);
break;
case PaddleDType::FLOAT32:
input_ptr = t->mutable_data<float>(ddim, place);
break;
case PaddleDType::INT32:
input_ptr = t->mutable_data<int32_t>(ddim, place);
break;
case PaddleDType::FLOAT16:
input_ptr = t->mutable_data<float16>(ddim, place);
break;
default:
LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false;
}
PADDLE_ENFORCE_NOT_NULL(
input_ptr,
paddle::platform::errors::Fatal(
"Cannot convert to LoDTensor because LoDTensor creation failed."));
PADDLE_ENFORCE_NOT_NULL(
pt.data.data(),
paddle::platform::errors::InvalidArgument(
"The data contained in the input PaddleTensor is illegal."));
if (platform::is_cpu_place(place)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
} else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with WITH_IPU, should not reach here."));
#endif
} else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<const phi::GPUContext *>(pool.Get(place));
auto dst_gpu_place = place;
memory::Copy(dst_gpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with CUDA, should not reach here."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto dst_xpu_place = place;
memory::Copy(dst_xpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with XPU, should not reach here."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The analysis predictor supports CPU, GPU and XPU now."));
}
return true;
} }
} // namespace jit } // namespace jit
......
...@@ -770,7 +770,11 @@ PyObject* ToPyObject(const std::vector<std::vector<size_t>>& value) { ...@@ -770,7 +770,11 @@ PyObject* ToPyObject(const std::vector<std::vector<size_t>>& value) {
PyObject* ToPyObject(const std::vector<paddle::Tensor>& value, PyObject* ToPyObject(const std::vector<paddle::Tensor>& value,
bool return_py_none_if_not_initialize) { bool return_py_none_if_not_initialize) {
// NOTE(liuyuanle): I encountered a bug(access violation) in windows. ref to
// https://stackoverflow.com/questions/55598839/how-to-fix-access-violation-error-when-returning-pyobject-from-c-function-usin
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New((Py_ssize_t)value.size()); PyObject* result = PyList_New((Py_ssize_t)value.size());
PyGILState_Release(gstate);
for (size_t i = 0; i < value.size(); i++) { for (size_t i = 0; i < value.size(); i++) {
if (!value[i].initialized() && return_py_none_if_not_initialize) { if (!value[i].initialized() && return_py_none_if_not_initialize) {
......
...@@ -65,7 +65,7 @@ constexpr int NPY_UINT16_ = 4; ...@@ -65,7 +65,7 @@ constexpr int NPY_UINT16_ = 4;
// paddle::platform::float16 as numpy.float16. // paddle::platform::float16 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776 // Ref: https://github.com/pybind/pybind11/issues/1776
template <> template <>
struct npy_format_descriptor<paddle_infer::float16> { struct npy_format_descriptor<phi::dtype::float16> {
static py::dtype dtype() { static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_); handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
return reinterpret_borrow<py::dtype>(ptr); return reinterpret_borrow<py::dtype>(ptr);
...@@ -180,7 +180,7 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { ...@@ -180,7 +180,7 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
dt = py::dtype::of<float>(); dt = py::dtype::of<float>();
break; break;
case PaddleDType::FLOAT16: case PaddleDType::FLOAT16:
dt = py::dtype::of<paddle_infer::float16>(); dt = py::dtype::of<phi::dtype::float16>();
break; break;
case PaddleDType::UINT8: case PaddleDType::UINT8:
dt = py::dtype::of<uint8_t>(); dt = py::dtype::of<uint8_t>();
...@@ -264,7 +264,7 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT ...@@ -264,7 +264,7 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT
ToPaddleInferPlace(input_tensor.place().GetType())); ToPaddleInferPlace(input_tensor.place().GetType()));
} else if (input_tensor.dtype() == phi::DataType::FLOAT16) { } else if (input_tensor.dtype() == phi::DataType::FLOAT16) {
tensor.ShareExternalData( tensor.ShareExternalData(
static_cast<paddle::platform::float16 *>(input_tensor.data()), static_cast<phi::dtype::float16 *>(input_tensor.data()),
shape, shape,
ToPaddleInferPlace(input_tensor.place().GetType())); ToPaddleInferPlace(input_tensor.place().GetType()));
} else if (input_tensor.dtype() == phi::DataType::INT32) { } else if (input_tensor.dtype() == phi::DataType::INT32) {
...@@ -353,7 +353,7 @@ size_t PaddleGetDTypeSize(PaddleDType dt) { ...@@ -353,7 +353,7 @@ size_t PaddleGetDTypeSize(PaddleDType dt) {
size = sizeof(float); size = sizeof(float);
break; break;
case PaddleDType::FLOAT16: case PaddleDType::FLOAT16:
size = sizeof(paddle_infer::float16); size = sizeof(phi::dtype::float16);
break; break;
case PaddleDType::INT8: case PaddleDType::INT8:
size = sizeof(int8_t); size = sizeof(int8_t);
...@@ -392,8 +392,8 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT ...@@ -392,8 +392,8 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data())); tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
break; break;
case PaddleDType::FLOAT16: case PaddleDType::FLOAT16:
tensor.copy_to_cpu<paddle::platform::float16>( tensor.copy_to_cpu<phi::dtype::float16>(
static_cast<paddle::platform::float16 *>(array.mutable_data())); static_cast<phi::dtype::float16 *>(array.mutable_data()));
break; break;
case PaddleDType::UINT8: case PaddleDType::UINT8:
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data())); tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
...@@ -432,8 +432,8 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT ...@@ -432,8 +432,8 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data())); tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
break; break;
case PaddleDType::FLOAT16: case PaddleDType::FLOAT16:
tensor.CopyToCpu<paddle::platform::float16>( tensor.CopyToCpu<phi::dtype::float16>(
static_cast<paddle::platform::float16 *>(array.mutable_data())); static_cast<phi::dtype::float16 *>(array.mutable_data()));
break; break;
case PaddleDType::UINT8: case PaddleDType::UINT8:
tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data())); tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
...@@ -1062,6 +1062,16 @@ void BindPaddleInferPredictor(py::module *m) { ...@@ -1062,6 +1062,16 @@ void BindPaddleInferPredictor(py::module *m) {
.def("get_output_names", &paddle_infer::Predictor::GetOutputNames) .def("get_output_names", &paddle_infer::Predictor::GetOutputNames)
.def("get_input_handle", &paddle_infer::Predictor::GetInputHandle) .def("get_input_handle", &paddle_infer::Predictor::GetInputHandle)
.def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle) .def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle)
.def(
"run",
[](paddle_infer::Predictor &self, py::handle py_in_tensor_list) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
std::vector<paddle::Tensor> outputs;
self.Run(in_tensor_list, &outputs);
return py::handle(ToPyObject(outputs));
},
py::arg("inputs"))
.def("run", [](paddle_infer::Predictor &self) { self.Run(); }) .def("run", [](paddle_infer::Predictor &self) { self.Run(); })
.def("clone", .def("clone",
[](paddle_infer::Predictor &self) { return self.Clone(nullptr); }) [](paddle_infer::Predictor &self) { return self.Clone(nullptr); })
...@@ -1091,9 +1101,9 @@ void BindZeroCopyTensor(py::module *m) { ...@@ -1091,9 +1101,9 @@ void BindZeroCopyTensor(py::module *m) {
.def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>) .def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>)
.def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>) .def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>)
.def("copy_from_cpu", &ZeroCopyTensorCreate<float>) .def("copy_from_cpu", &ZeroCopyTensorCreate<float>)
.def("copy_from_cpu", &ZeroCopyTensorCreate<phi::dtype::float16>)
// NOTE(liuyuanle): double must be bound after float. // NOTE(liuyuanle): double must be bound after float.
.def("copy_from_cpu", &ZeroCopyTensorCreate<double>) .def("copy_from_cpu", &ZeroCopyTensorCreate<double>)
.def("copy_from_cpu", &ZeroCopyTensorCreate<paddle_infer::float16>)
.def("copy_from_cpu", &ZeroCopyTensorCreate<bool>) .def("copy_from_cpu", &ZeroCopyTensorCreate<bool>)
.def("copy_from_cpu", &ZeroCopyStringTensorCreate) .def("copy_from_cpu", &ZeroCopyStringTensorCreate)
.def("copy_to_cpu", &ZeroCopyTensorToNumpy) .def("copy_to_cpu", &ZeroCopyTensorToNumpy)
...@@ -1116,10 +1126,9 @@ void BindPaddleInferTensor(py::module *m) { ...@@ -1116,10 +1126,9 @@ void BindPaddleInferTensor(py::module *m) {
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>) .def("_copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>) .def("_copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<float>) .def("_copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<phi::dtype::float16>)
// NOTE(liuyuanle): double must be bound after float. // NOTE(liuyuanle): double must be bound after float.
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<double>) .def("_copy_from_cpu_bind", &PaddleInferTensorCreate<double>)
.def("_copy_from_cpu_bind",
&PaddleInferTensorCreate<paddle_infer::float16>)
.def("_copy_from_cpu_bind", &PaddleInferTensorCreate<bool>) .def("_copy_from_cpu_bind", &PaddleInferTensorCreate<bool>)
.def("_copy_from_cpu_bind", &PaddleInferStringTensorCreate) .def("_copy_from_cpu_bind", &PaddleInferStringTensorCreate)
.def("_share_external_data_bind", &PaddleInferShareExternalData) .def("_share_external_data_bind", &PaddleInferShareExternalData)
......
...@@ -416,7 +416,7 @@ class PADDLE_API Tensor final { ...@@ -416,7 +416,7 @@ class PADDLE_API Tensor final {
/** /**
* @brief Return the name of Tensor. * @brief Return the name of Tensor.
* @note Used to adapt original execution mechanism and debug analysis * @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future. * in the development of new dygraph.
* *
* @return const std::string& * @return const std::string&
*/ */
...@@ -425,7 +425,7 @@ class PADDLE_API Tensor final { ...@@ -425,7 +425,7 @@ class PADDLE_API Tensor final {
/** /**
* @brief Set name of Tensor. * @brief Set name of Tensor.
* @note Used to adapt original execution mechanism and debug analysis * @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future. * in the development of new dygraph.
* *
* @param const std::string& name * @param const std::string& name
*/ */
...@@ -657,7 +657,7 @@ class PADDLE_API Tensor final { ...@@ -657,7 +657,7 @@ class PADDLE_API Tensor final {
/** /**
* Tensor name: used to adapt original execution mechanism and debug analysis * Tensor name: used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future. * in the development of new dygraph.
*/ */
std::string name_{""}; std::string name_{""};
......
...@@ -136,6 +136,7 @@ Tensor add_n_impl(const std::vector<Tensor>& x) { ...@@ -136,6 +136,7 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
Tensor out; Tensor out;
copy(x, place, blocking, &out); copy(x, place, blocking, &out);
out.set_name(x.name());
return out; return out;
} }
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
from paddle.inference import Config, create_predictor
class TestNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc1 = paddle.nn.Linear(4, 4)
self.fc2 = paddle.nn.Linear(4, 4)
def forward(self, x1, x2):
y1 = self.fc1(x1)
y2 = self.fc2(x2)
return y1 + y2
@unittest.skipIf(
not paddle.is_compiled_with_cuda(), 'should compile with cuda.'
)
class TestPredictorRunWithTensor(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
net = TestNet()
model = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4], dtype='float32', name='input0'
),
paddle.static.InputSpec(
shape=[None, 4], dtype='float32', name='input1'
),
],
)
paddle.jit.save(
model,
os.path.join(
self.temp_dir.name, 'test_predictor_run_model/inference'
),
)
def tearDown(self):
self.temp_dir.cleanup()
def init_predictor(self):
config = Config(
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdmodel',
),
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdiparams',
),
)
config.enable_use_gpu(256, 0)
config.enable_memory_optim()
predictor = create_predictor(config)
return predictor
def get_inputs(self):
input0 = np.array([[1, 2, 3, 4], [2, 3, 4, 5]]).astype(np.float32)
input1 = np.array([[0.1, 0.2, 0.3, 0.4], [1.2, 1.3, 1.4, 1.5]]).astype(
np.float32
)
input0_tensor = paddle.to_tensor(input0)
input1_tensor = paddle.to_tensor(input1)
return [input0_tensor, input1_tensor]
def get_disorder_output(self):
predictor = self.init_predictor()
[input0_tensor, input1_tensor] = self.get_inputs()
input_names = predictor.get_input_names()
input0_tensor.name = input_names[0]
input1_tensor.name = input_names[1]
# disorder
inputs = [input1_tensor, input0_tensor]
outputs = predictor.run(inputs)
return outputs[0]
def get_inorder_output(self):
predictor = self.init_predictor()
[input0_tensor, input1_tensor] = self.get_inputs()
# inorder
inputs = [input0_tensor, input1_tensor]
outputs = predictor.run(inputs)
return outputs[0]
def test_output(self):
inorder_output = self.get_inorder_output()
disorder_output = self.get_disorder_output()
assert np.allclose(
inorder_output.numpy().flatten(), disorder_output.numpy().flatten()
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册