未验证 提交 e07900d3 编写于 作者: Y Yan Chunwei 提交者: GitHub

cache tensor ptr in ZeroCopyTensor (#15352)

上级 b7916440
......@@ -33,9 +33,15 @@ void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
tensor->Resize(framework::make_ddim(shape));
}
// Resolves the backing tensor on demand: FindTensor() is called only while the
// cached `tensor_` member is still null, and the cached pointer is then exposed
// to the enclosing function as a local `framework::LoDTensor *tensor`.
// Because the expansion declares a local named `tensor`, it can be used at
// most once per scope.
#define EAGER_GET_TENSOR                          \
  if (tensor_ == nullptr) tensor_ = FindTensor(); \
  framework::LoDTensor *tensor = static_cast<framework::LoDTensor *>(tensor_);
template <typename T>
T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
EAGER_GET_TENSOR;
switch (static_cast<int>(place)) {
case static_cast<int>(PaddlePlace::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
......@@ -52,7 +58,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
template <typename T>
T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
EAGER_GET_TENSOR;
auto *res = tensor->data<T>();
if (platform::is_cpu_place(tensor->place())) {
......@@ -87,13 +93,13 @@ void *ZeroCopyTensor::FindTensor() const {
}
std::vector<int64_t> ZeroCopyTensor::shape() const {
auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
PADDLE_ENFORCE(tensor, "not found tensor called %s in the scope", name_);
EAGER_GET_TENSOR;
PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_);
return framework::vectorize(tensor->dims());
}
void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
EAGER_GET_TENSOR;
framework::LoD lod;
for (auto &level : x) {
lod.emplace_back(level);
......@@ -102,8 +108,8 @@ void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
}
std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
EAGER_GET_TENSOR;
std::vector<std::vector<size_t>> res;
auto *tensor = static_cast<framework::LoDTensor *>(FindTensor());
for (auto &level : tensor->lod()) {
res.emplace_back(level);
}
......
......@@ -146,6 +146,9 @@ class ZeroCopyTensor {
bool input_or_output_;
friend class AnalysisPredictor;
void* scope_{nullptr};
// The corresponding tensor pointer inside Paddle workspace is cached for
// performance.
mutable void* tensor_{nullptr};
};
/** A simple Inference API for Paddle.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册