// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {

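// Resizes the named tensor in the runtime scope to the given shape.
// Only input tensors may be reshaped; reshaping an output is rejected below.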
void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
      platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
                    platform::errors::PermissionDenied(
                        "Can't reshape the output tensor; it is read-only."));
  PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
                                      "The scope should not be nullptr."));
  auto *scope = static_cast<framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::PreconditionNotMet(
               "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<framework::LoDTensor>();
  tensor->Resize(framework::make_ddim(shape));
}

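// Lazily resolves the underlying LoDTensor: caches the pointer in tensor_ on
// first use and declares a typed local `tensor` in the enclosing function.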
#define EAGER_GET_TENSOR    \
  if (!tensor_) {           \
    tensor_ = FindTensor(); \
  }                         \
  auto *tensor = static_cast<framework::LoDTensor *>(tensor_);

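// Allocates (if necessary) and returns a mutable buffer of element type T on
// the requested place. Reshape() must have been called first so numel() > 0.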
template <typename T>
T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
  EAGER_GET_TENSOR;
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
      platform::errors::PreconditionNotMet(
          "You should call ZeroCopyTensor::Reshape(const std::vector<int> "
          "&shape) before retrieving mutable_data from the input tensor."));
  switch (static_cast<int>(place)) {
    case static_cast<int>(PaddlePlace::kCPU): {
      return tensor->mutable_data<T>(platform::CPUPlace());
    }
    case static_cast<int>(PaddlePlace::kGPU): {
      return tensor->mutable_data<T>(platform::CUDAPlace(device_));
    }
    default:
      PADDLE_THROW(platform::errors::Unavailable("Unsupported place: %d",
                                                 static_cast<int>(place)));
      break;
  }
  return nullptr;
}

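// Returns the raw data pointer and reports where it lives (CPU/GPU) and how
// many elements it holds.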
template <typename T>
T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
  EAGER_GET_TENSOR;
  auto *res = tensor->data<T>();

  if (platform::is_cpu_place(tensor->place())) {
    *place = PaddlePlace::kCPU;
  } else if (platform::is_gpu_place(tensor->place())) {
    *place = PaddlePlace::kGPU;
  } else {
    *place = PaddlePlace::kUNK;
  }

  *size = tensor->numel();
  return res;
}

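// Maps the framework's VarType to the public PaddleDType enum, falling back
// to FLOAT32 for unrecognized types.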
PaddleDType ZeroCopyTensor::type() const {
  EAGER_GET_TENSOR;
  auto type = tensor->type();
  if (type == framework::proto::VarType::FP32) {
    return PaddleDType::FLOAT32;
  } else if (type == framework::proto::VarType::INT64) {
    return PaddleDType::INT64;
  } else if (type == framework::proto::VarType::INT32) {
    return PaddleDType::INT32;
  } else if (type == framework::proto::VarType::UINT8) {
    return PaddleDType::UINT8;
  }
  return PaddleDType::FLOAT32;
}

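// Copies numel() elements from a host buffer into the tensor; when the
// tensor lives on the GPU, the copy is issued on the device's CUDA stream.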
template <typename T>
void ZeroCopyTensor::copy_from_cpu(const T *data) {
  EAGER_GET_TENSOR;
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
                    platform::errors::PreconditionNotMet(
                        "You should call ZeroCopyTensor::Reshape(const "
                        "std::vector<int> &shape) before copying data "
                        "from the CPU."));
  size_t ele_size = tensor->numel() * sizeof(T);

  if (place_ == PaddlePlace::kCPU) {
    auto *t_data = tensor->mutable_data<T>(platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else {
#ifdef PADDLE_WITH_CUDA
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    platform::CUDAPlace gpu_place(device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
    auto *dev_ctx =
        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));

    memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
                 data, ele_size, dev_ctx->stream());
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Not compiled with CUDA, should not reach here."));
#endif
  }
}

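// Copies the tensor's contents into a caller-provided host buffer; for GPU
// tensors the copy is issued on the device stream and then synchronized.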
template <typename T>
void ZeroCopyTensor::copy_to_cpu(T *data) {
  EAGER_GET_TENSOR;
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  if (platform::is_cpu_place(t_place)) {
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
  } else {
#ifdef PADDLE_WITH_CUDA
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
    auto *dev_ctx =
        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
    memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
                 t_data, ele_num * sizeof(T), dev_ctx->stream());

    cudaStreamSynchronize(dev_ctx->stream());
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Not compiled with CUDA, should not reach here."));
#endif
  }
}
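
// Explicit instantiations for the element types exposed through the
// inference API.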
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<float>(
    const float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int64_t>(
    const int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int32_t>(
    const int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<uint8_t>(
    const uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int8_t>(
    const int8_t *data);

template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int8_t>(int8_t *data);

template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place,
                                                          int *size) const;
template PD_INFER_DECL int64_t *ZeroCopyTensor::data<int64_t>(
    PaddlePlace *place, int *size) const;
template PD_INFER_DECL int32_t *ZeroCopyTensor::data<int32_t>(
    PaddlePlace *place, int *size) const;
template PD_INFER_DECL uint8_t *ZeroCopyTensor::data<uint8_t>(
    PaddlePlace *place, int *size) const;
template PD_INFER_DECL int8_t *ZeroCopyTensor::data<int8_t>(PaddlePlace *place,
                                                            int *size) const;

template PD_INFER_DECL float *ZeroCopyTensor::mutable_data<float>(
    PaddlePlace place);
template PD_INFER_DECL int64_t *ZeroCopyTensor::mutable_data<int64_t>(
    PaddlePlace place);
template PD_INFER_DECL int32_t *ZeroCopyTensor::mutable_data<int32_t>(
    PaddlePlace place);
template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
    PaddlePlace place);
template PD_INFER_DECL int8_t *ZeroCopyTensor::mutable_data<int8_t>(
    PaddlePlace place);

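// Looks up the LoDTensor named name_ in the runtime scope; requires SetName()
// to have been called and the scope to be non-null.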
void *ZeroCopyTensor::FindTensor() const {
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
      platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
                                      "The scope should not be nullptr."));
  auto *scope = static_cast<framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::PreconditionNotMet(
               "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<framework::LoDTensor>();
  return tensor;
}

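// Returns the tensor's dimensions as a vector of ints.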
std::vector<int> ZeroCopyTensor::shape() const {
  EAGER_GET_TENSOR;
  PADDLE_ENFORCE_NOT_NULL(
      tensor_, platform::errors::PreconditionNotMet(
                   "No tensor called %s found in the scope.", name_));
  return framework::vectorize<int>(tensor->dims());
}

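// Sets the tensor's level-of-detail (LoD) information, which describes how
// variable-length sequences are packed into the tensor.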
void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
  EAGER_GET_TENSOR;
  framework::LoD lod;
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

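// Returns a copy of the tensor's LoD information.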
std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
  EAGER_GET_TENSOR;
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

}  // namespace paddle