diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 0f540699b8ffea94c3f3aaf3354a0462e0e674a9..f60ff40c5da3e9e03c2cb3583263394cb82db805 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -33,9 +33,15 @@ void ZeroCopyTensor::Reshape(const std::vector &shape) { tensor->Resize(framework::make_ddim(shape)); } +#define EAGER_GET_TENSOR \ + if (!tensor_) { \ + tensor_ = FindTensor(); \ + } \ + auto *tensor = static_cast(tensor_); + template T *ZeroCopyTensor::mutable_data(PaddlePlace place) { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; switch (static_cast(place)) { case static_cast(PaddlePlace::kCPU): { return tensor->mutable_data(platform::CPUPlace()); @@ -52,7 +58,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { template T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; auto *res = tensor->data(); if (platform::is_cpu_place(tensor->place())) { @@ -87,13 +93,13 @@ void *ZeroCopyTensor::FindTensor() const { } std::vector ZeroCopyTensor::shape() const { - auto *tensor = static_cast(FindTensor()); - PADDLE_ENFORCE(tensor, "not found tensor called %s in the scope", name_); + EAGER_GET_TENSOR; + PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_); return framework::vectorize(tensor->dims()); } void ZeroCopyTensor::SetLoD(const std::vector> &x) { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; framework::LoD lod; for (auto &level : x) { lod.emplace_back(level); @@ -102,8 +108,8 @@ void ZeroCopyTensor::SetLoD(const std::vector> &x) { } std::vector> ZeroCopyTensor::lod() const { + EAGER_GET_TENSOR; std::vector> res; - auto *tensor = static_cast(FindTensor()); for (auto &level : tensor->lod()) { res.emplace_back(level); } diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 832c8cdf2849279c4c32a81e9f81ef522c401b86..d9edcf7cc5eefbd64883b679b39838f3f0e6a993 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -146,6 +146,9 @@ class ZeroCopyTensor { bool input_or_output_; friend class AnalysisPredictor; void* scope_{nullptr}; + // The corresponding tensor pointer inside Paddle workspace is cached for + // performance. + mutable void* tensor_{nullptr}; }; /** A simple Inference API for Paddle.