Unverified commit e07900d3, authored by Yan Chunwei, committed by GitHub

cache tensor ptr in ZeroCopyTensor (#15352)

Parent b7916440
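The change below replaces the repeated FindTensor() scope lookup in each ZeroCopyTensor accessor with a lazily initialized, cached tensor_ pointer, wrapped in the EAGER_GET_TENSOR macro. As a rough, self-contained sketch of the same caching pattern (Scope, Tensor, and CachedTensorHandle are hypothetical stand-ins for illustration, not the actual Paddle classes):

// A minimal sketch of the lazy pointer-caching pattern this commit applies,
// using hypothetical stand-in types rather than the real Paddle classes.
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Tensor {
  std::vector<int64_t> dims;
};

struct Scope {
  std::unordered_map<std::string, Tensor> vars;
  Tensor *Find(const std::string &name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

class CachedTensorHandle {
 public:
  CachedTensorHandle(Scope *scope, std::string name)
      : scope_(scope), name_(std::move(name)) {}

  std::vector<int64_t> shape() const {
    // Mirrors EAGER_GET_TENSOR: look the tensor up once, then reuse the
    // cached pointer on every later call.
    if (!tensor_) tensor_ = scope_->Find(name_);
    assert(tensor_ && "not found tensor in the scope");
    return tensor_->dims;
  }

 private:
  Scope *scope_;
  std::string name_;
  // Cached for performance; mutable so const accessors can fill it lazily.
  mutable Tensor *tensor_{nullptr};
};

int main() {
  Scope scope;
  scope.vars["x"].dims = {2, 3};
  CachedTensorHandle h(&scope, "x");
  assert(h.shape() == (std::vector<int64_t>{2, 3}));  // first call does the lookup
  assert(h.shape() == (std::vector<int64_t>{2, 3}));  // second call hits the cache
  return 0;
}

The cached pointer is declared mutable for the same reason the new tensor_ member in the header below is: const accessors such as shape() and lod() must still be able to populate the cache on first use.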
...
@@ -33,9 +33,15 @@ void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
   tensor->Resize(framework::make_ddim(shape));
 }
 
+#define EAGER_GET_TENSOR    \
+  if (!tensor_) {           \
+    tensor_ = FindTensor(); \
+  }                         \
+  auto *tensor = static_cast<framework::LoDTensor *>(tensor_);
+
 template <typename T>
 T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
-  auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
+  EAGER_GET_TENSOR;
   switch (static_cast<int>(place)) {
     case static_cast<int>(PaddlePlace::kCPU): {
       return tensor->mutable_data<T>(platform::CPUPlace());
@@ -52,7 +58,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
 
 template <typename T>
 T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
-  auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
+  EAGER_GET_TENSOR;
   auto *res = tensor->data<T>();
 
   if (platform::is_cpu_place(tensor->place())) {
@@ -87,13 +93,13 @@ void *ZeroCopyTensor::FindTensor() const {
 }
 
 std::vector<int64_t> ZeroCopyTensor::shape() const {
-  auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
-  PADDLE_ENFORCE(tensor, "not found tensor called %s in the scope", name_);
+  EAGER_GET_TENSOR;
+  PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_);
   return framework::vectorize(tensor->dims());
 }
 
 void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
-  auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
+  EAGER_GET_TENSOR;
   framework::LoD lod;
   for (auto &level : x) {
     lod.emplace_back(level);
@@ -102,8 +108,8 @@ void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
 }
 
 std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
+  EAGER_GET_TENSOR;
   std::vector<std::vector<size_t>> res;
-  auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
   for (auto &level : tensor->lod()) {
     res.emplace_back(level);
   }
...
...
@@ -146,6 +146,9 @@ class ZeroCopyTensor {
   bool input_or_output_;
   friend class AnalysisPredictor;
   void* scope_{nullptr};
+  // The corresponding tensor pointer inside Paddle workspace is cached for
+  // performance.
+  mutable void* tensor_{nullptr};
 };
 
 /** A simple Inference API for Paddle.
...