dense_tensor_kernels.cc 13.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
W
Wilber 已提交
16
#include <memory>
W
Wilber 已提交
17
#include "llvm/Support/ErrorHandling.h"
W
Wilber 已提交
18
#include "paddle/infrt/backends/host/phi_allocator.h"
19
#include "paddle/infrt/common/string.h"
20 21
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/infrt/kernel/phi/context_kernels.h"
22 23
#include "paddle/infrt/paddle/model_parser.h"
#include "paddle/infrt/paddle/scope.h"
24
#include "paddle/infrt/tensor/tensor_map.h"
W
Wilber 已提交
25 26
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/place.h"
W
Wilber 已提交
27
#include "paddle/phi/core/allocator.h"
28
#include "paddle/phi/core/dense_tensor.h"
W
Wilber 已提交
29 30 31 32

#ifdef INFRT_WITH_GPU
#include <cuda_runtime.h>
#endif
33

34 35
namespace infrt {
namespace kernel {
36
namespace phi {
37

38 39
// Builds a phi::DenseTensor whose meta (dtype, dims, layout) comes from the
// kernel attributes and whose allocator is the CPU context's allocator.
//
// NOTE(review): `lod` is accepted to keep the kernel signature stable but is
// not read; the tensor meta is always created with an empty LoD.
::phi::DenseTensor CreateDenseTensor(
    const ::phi::CPUContext& context,
    host_context::Attribute<std::vector<int64_t>> dims,
    host_context::Attribute<std::vector<int64_t>> lod,
    host_context::Attribute<::infrt::LayoutType> layout,
    host_context::Attribute<::infrt::PrecisionType> precision) {
  auto* allocator = const_cast<::phi::Allocator*>(&context.GetAllocator());
  ::phi::DenseTensorMeta meta(ConvertPrecisionToPhi(precision.get()),
                              ::phi::make_ddim(dims.get()),
                              ConvertLayoutToPhi(layout.get()),
                              {});
  return ::phi::DenseTensor(allocator, std::move(meta));
}

52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
// Creates a FLOAT32 DenseTensor with the given dims/layout on the CPU
// context's allocator and sets every element to the scalar `value`.
//
// NOTE(review): `lod` is not read; the tensor meta uses an empty LoD.
::phi::DenseTensor CreateInitedDenseTensorF32(
    const ::phi::CPUContext& context,
    host_context::Attribute<std::vector<int64_t>> dims,
    host_context::Attribute<std::vector<int64_t>> lod,
    host_context::Attribute<::infrt::LayoutType> layout,
    host_context::Attribute<float> value) {
  ::phi::DenseTensorMeta meta(
      ConvertPrecisionToPhi(::infrt::PrecisionType::FLOAT32),
      ::phi::make_ddim(dims.get()),
      ConvertLayoutToPhi(layout.get()),
      {});
  ::phi::DenseTensor tensor(
      const_cast<::phi::Allocator*>(&context.GetAllocator()), std::move(meta));

  float* out = tensor.mutable_data<float>(::phi::CPUPlace());
  const float fill_value = value.get();
  const int64_t count = tensor.numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    out[idx] = fill_value;
  }
  return tensor;
}

72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
// Creates a FLOAT32 DenseTensor with the given dims/layout on the CPU
// context's allocator and copies `values` into it element by element.
//
// Fails via CHECK_EQ unless values.size() equals the tensor's numel().
// NOTE(review): `lod` is not read; the tensor meta uses an empty LoD.
::phi::DenseTensor CreateHostInitedDenseTensorF32(
    const ::phi::CPUContext& context,
    host_context::Attribute<std::vector<int64_t>> dims,
    host_context::Attribute<std::vector<int64_t>> lod,
    host_context::Attribute<::infrt::LayoutType> layout,
    host_context::Attribute<std::vector<float>> values) {
  ::phi::DenseTensorMeta meta(
      ConvertPrecisionToPhi(::infrt::PrecisionType::FLOAT32),
      ::phi::make_ddim(dims.get()),
      ConvertLayoutToPhi(layout.get()),
      {});
  ::phi::DenseTensor tensor(
      const_cast<::phi::Allocator*>(&context.GetAllocator()), std::move(meta));

  const auto& src = values.get();
  CHECK_EQ(tensor.numel(), static_cast<int64_t>(src.size()));
  float* dst = tensor.mutable_data<float>(::phi::CPUPlace());
  for (int64_t idx = 0; idx < tensor.numel(); ++idx) {
    dst[idx] = src[idx];
  }
  return tensor;
}

W
Wilber 已提交
93 94 95 96 97 98 99 100 101 102 103 104 105 106
// GPU counterpart of CreateDenseTensor: builds a phi::DenseTensor whose meta
// (dtype, dims, layout) comes from the kernel attributes and whose allocator
// is the GPU context's allocator.
//
// NOTE(review): `lod` is not read; the tensor meta uses an empty LoD.
::phi::DenseTensor CreateGPUDenseTensor(
    const ::phi::GPUContext& context,
    host_context::Attribute<std::vector<int64_t>> dims,
    host_context::Attribute<std::vector<int64_t>> lod,
    host_context::Attribute<::infrt::LayoutType> layout,
    host_context::Attribute<::infrt::PrecisionType> precision) {
  auto* allocator = const_cast<::phi::Allocator*>(&context.GetAllocator());
  ::phi::DenseTensorMeta meta(ConvertPrecisionToPhi(precision.get()),
                              ::phi::make_ddim(dims.get()),
                              ConvertLayoutToPhi(layout.get()),
                              {});
  return ::phi::DenseTensor(allocator, std::move(meta));
}

107
void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
108
                        host_context::Attribute<std::vector<float>> value) {
W
Wilber 已提交
109
  auto place = dense_tensor->place();
110
  float* a_data = dense_tensor->mutable_data<float>(place);
W
Wilber 已提交
111 112 113 114 115 116 117 118 119 120 121 122 123 124
  if (place.GetType() == ::phi::AllocationType::CPU) {
    for (int64_t i = 0; i < dense_tensor->numel(); ++i) {
      a_data[i] = (value.get())[i];
    }
  } else if (place.GetType() == ::phi::AllocationType::GPU) {
#ifdef INFRT_WITH_GPU
    // TODO(wilber): how to set the stream parameter to copy with stream.
    cudaMemcpy(a_data,
               value.get().data(),
               sizeof(float) * value.get().size(),
               cudaMemcpyHostToDevice);
#endif
  } else {
    llvm_unreachable("temporarily not support other target.");
125 126
  }
}
127

128
void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
W
Wilber 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
#ifndef INFRT_WITH_GPU
#define PRINT_META_DATA(PHI_DATATYPE, DTYPE)                \
  case ::phi::DataType::PHI_DATATYPE: {                     \
    auto place = dense_tensor->place();                     \
    if (place.GetType() == ::phi::AllocationType::CPU) {    \
      DTYPE* data = dense_tensor->data<DTYPE>();            \
      if (dense_tensor->numel() == 0) break;                \
      std::cout << data[0];                                 \
      for (int64_t i = 1; i < dense_tensor->numel(); i++) { \
        std::cout << "," << data[i];                        \
      }                                                     \
    }                                                       \
    break;                                                  \
  }
#else
#define PRINT_META_DATA(PHI_DATATYPE, DTYPE)                     \
  case ::phi::DataType::PHI_DATATYPE: {                          \
    auto place = dense_tensor->place();                          \
    DTYPE* data = dense_tensor->data<DTYPE>();                   \
    if (dense_tensor->numel() == 0) break;                       \
    if (place.GetType() == ::phi::AllocationType::CPU) {         \
      std::cout << data[0];                                      \
      for (int64_t i = 1; i < dense_tensor->numel(); i++) {      \
        std::cout << "," << data[i];                             \
      }                                                          \
    } else if (place.GetType() == ::phi::AllocationType::GPU) {  \
      std::vector<DTYPE> host_data(dense_tensor->numel(), 0);    \
      cudaMemcpy(host_data.data(),                               \
                 data,                                           \
                 sizeof(DTYPE) * dense_tensor->numel(),          \
                 cudaMemcpyDeviceToHost);                        \
      std::cout << host_data[0];                                 \
      for (int64_t i = 1; i < dense_tensor->numel(); i++) {      \
        std::cout << "," << host_data[i];                        \
      }                                                          \
    } else {                                                     \
      llvm_unreachable("temporarily not support other target."); \
    }                                                            \
    break;                                                       \
168
  }
W
Wilber 已提交
169
#endif
170 171 172

  ::phi::DDim dims = dense_tensor->dims();
  std::cout << "dense_tensor: shape=shape" << dims.to_str() << ","
173
            << " value=[";
174 175 176 177 178 179 180 181 182
  switch (dense_tensor->dtype()) {
    PRINT_META_DATA(FLOAT32, float);
    PRINT_META_DATA(INT32, int32_t);
    default:
      std::cout << "Error! Unsupported data type!\n";
  }
  std::cout << "]\n";
#undef PRINT_META_DATA
}
183

184
// Loads the parameters of the program stored at `<file_path>/__model__`.
// Each persistable LOD_TENSOR variable is read from its own file
// `<file_path>/<var name>`.
//
// Variables named "feed"/"fetch" and non-persistable variables are skipped.
// Persistable variables of any other type are skipped with a warning.
// Deserialization uses a local CPUContext backed by a CpuPhiAllocator owned by
// this function. Aborts via CHECK if a tensor file cannot be opened;
// previously, a missing file was handed to DeserializeFromStream as a failed
// stream.
::infrt::phi::DenseTensorMap LoadParameters(const std::string& file_path) {
  std::cout << "loading params from: " << file_path << std::endl;
  ::infrt::phi::DenseTensorMap map;

  const std::string model_path = file_path + "/__model__";
  auto pb_proto_prog = paddle::LoadProgram(model_path);
  // Bind by reference: `pb_proto_prog` outlives the loop, so there is no need
  // to copy the whole block proto.
  const auto& main_block = pb_proto_prog->blocks(0);

  ::phi::CPUContext ctx;
  auto allocator = std::make_unique<backends::CpuPhiAllocator>();
  const auto* allocator_ptr = allocator.get();
  ctx.SetAllocator(allocator_ptr);
  ctx.SetHostAllocator(allocator_ptr);
  ctx.SetZeroAllocator(allocator_ptr);

  for (auto& var : main_block.vars()) {
    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
      continue;
    switch (var.type().type()) {
      case ::paddle::framework::proto::VarType_Type_LOD_TENSOR: {
        // Only open the file for variables that are actually deserialized.
        const std::string param_path = file_path + "/" + var.name();
        std::ifstream param_file(param_path, std::ios::binary);
        CHECK(param_file.is_open())
            << "Failed to open parameter file `" << param_path << "`.";
        std::unique_ptr<::phi::DenseTensor> tensor{
            std::make_unique<::phi::DenseTensor>()};
        ::infrt::paddle::DeserializeFromStream(param_file, tensor.get(), ctx);
        map.SetDenseTensor(var.name(), std::move(tensor));
      } break;
      default: {
        LOG(WARNING) << "Var `" << var.name() << "` type `"
                     << static_cast<int>(var.type().type())
                     << "` has not been supported now.";
      }
    }
  }
  return map;
}

220 221 222 223 224 225 226
// Kernel entry point: forwards the directory held by the `path` attribute to
// LoadParameters and returns the resulting tensor map.
::infrt::phi::DenseTensorMap LoadParams(
    host_context::Attribute<std::string> path) {
  const auto& params_dir = path.get();
  return LoadParameters(params_dir);
}

// Loads all persistable LOD_TENSOR parameters of the program at `model_path`
// from the single combined file `params_path`.
//
// Variables named "feed"/"fetch" and non-persistable variables are skipped.
// Any other persistable variable type aborts via LOG(FATAL); previously this
// used llvm_unreachable, which is undefined behavior in release builds.
// Tensors are read back-to-back from the file in the iteration order of a
// std::set of names, i.e. lexicographic order.
// NOTE(review): this assumes the combined file was written in sorted-name
// order; confirm against the writer.
// Aborts via CHECK if `params_path` cannot be opened.
::infrt::phi::DenseTensorMap LoadCombinedParameters(
    const std::string& model_path, const std::string& params_path) {
  ::infrt::phi::DenseTensorMap map;

  auto pb_proto_prog = paddle::LoadProgram(model_path);
  // `pb_proto_prog` outlives every use below; avoid copying the block proto.
  const auto& main_block = pb_proto_prog->blocks(0);

  std::ifstream param_file(params_path, std::ios::binary);
  CHECK(param_file.is_open())
      << "Failed to open combined parameter file `" << params_path << "`.";

  std::set<std::string> param_names;
  for (auto& var : main_block.vars()) {
    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) {
      continue;
    }
    if (var.type().type() !=
        ::paddle::framework::proto::VarType_Type_LOD_TENSOR) {
      LOG(FATAL) << "the tensor type is illegal. Var `" << var.name()
                 << "` has type " << static_cast<int>(var.type().type())
                 << ".";
    }
    param_names.emplace(var.name());
  }

  ::phi::CPUContext ctx;
  auto allocator = std::make_unique<backends::CpuPhiAllocator>();
  const auto* allocator_ptr = allocator.get();
  ctx.SetAllocator(allocator_ptr);
  ctx.SetHostAllocator(allocator_ptr);
  ctx.SetZeroAllocator(allocator_ptr);

  for (const auto& name : param_names) {
    std::unique_ptr<::phi::DenseTensor> tensor{
        std::make_unique<::phi::DenseTensor>()};
    ::infrt::paddle::DeserializeFromStream(param_file, tensor.get(), ctx);
    map.SetDenseTensor(name, std::move(tensor));
  }

  return map;
}

263 264 265 266 267 268
// Kernel entry point: forwards the `model_path` and `params_path` attributes
// to LoadCombinedParameters and returns the resulting tensor map.
::infrt::phi::DenseTensorMap LoadCombinedParams(
    host_context::Attribute<std::string> model_path,
    host_context::Attribute<std::string> params_path) {
  const auto& model_file = model_path.get();
  const auto& params_file = params_path.get();
  return LoadCombinedParameters(model_file, params_file);
}

269 270 271 272 273 274 275 276 277 278 279 280
// Looks up the tensor registered under `name` in `map` and returns `*tensor`
// by value (phi::DenseTensor copy semantics apply).
// Aborts via CHECK, naming the requested key, if GetDenseTensor returns null.
::phi::DenseTensor TensorMapGetTensor(
    const ::infrt::phi::DenseTensorMap& map,
    host_context::Attribute<std::string> name) {
  const auto& key = name.get();
  auto* tensor = map.GetDenseTensor(key);
  CHECK(tensor) << "Tensor `" << key << "` is not found in the tensor map.";
  return *tensor;
}

// Returns the number of entries reported by map.size(), converted to int32_t.
int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map) {
  const auto entry_count = map.size();
  return static_cast<int32_t>(entry_count);
}

W
Wilber 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
#ifdef INFRT_WITH_GPU
// Returns the byte width of one element of `data_type`.
// UNDEFINED yields 0; any data type not listed here hits llvm_unreachable.
inline size_t SizeOfDataType(::phi::DataType data_type) {
  size_t width = 0;
  switch (data_type) {
    case ::phi::DataType::UNDEFINED:
      width = 0;
      break;
    case ::phi::DataType::BOOL:
    case ::phi::DataType::UINT8:
    case ::phi::DataType::INT8:
      width = 1;
      break;
    case ::phi::DataType::BFLOAT16:
    case ::phi::DataType::FLOAT16:
    case ::phi::DataType::INT16:
    case ::phi::DataType::UINT16:
      width = 2;
      break;
    case ::phi::DataType::FLOAT32:
    case ::phi::DataType::INT32:
    case ::phi::DataType::UINT32:
      width = 4;
      break;
    case ::phi::DataType::FLOAT64:
    case ::phi::DataType::INT64:
    case ::phi::DataType::UINT64:
    case ::phi::DataType::COMPLEX64:
      width = 8;
      break;
    case ::phi::DataType::COMPLEX128:
      width = 16;
      break;
    default:
      llvm_unreachable("should not reach here");
  }
  return width;
}
312 313 314 315
// Copies `input` into `output` between host and device memory on the stream
// of `context`. In both directions `output` is resized to input.dims() and its
// storage is requested for input.dtype() with exactly
// SizeOfDataType(input.dtype()) * input.numel() bytes.
//
// d2h == true : `input` must be on the GPU. Storage comes from
//               context.HostAlloc, the copy is cudaMemcpyAsync
//               device-to-host, and the stream is then synchronized.
// d2h == false: `input` must be on the CPU or in GPU-pinned memory. Storage
//               comes from context.Alloc, and the copy is cudaMemcpyAsync
//               host-to-device with no synchronization here. A pinned `input`
//               must therefore stay valid until the stream executes the copy.
// Any CUDA error aborts via CHECK.
//
// Previously the d2h path resized and allocated `output` only when it had
// fewer elements than `input`. A larger `output` kept stale dims and dtype,
// and because reuse was decided by element count rather than bytes, a reused
// buffer could be overrun when the dtype grew.
// NOTE(review): phi's Alloc/HostAlloc (DenseTensor::AllocateFrom) are expected
// to keep an existing holder that is already large enough, preserving the old
// "avoid malloc" intent. Confirm for the phi version in use.
void GpuMemCpy(const ::phi::DenseTensor& input,
               const ::phi::GPUContext& context,
               bool d2h,
               ::phi::DenseTensor* output) {
  const size_t num_bytes = SizeOfDataType(input.dtype()) * input.numel();

  if (d2h) {
    CHECK(input.place().GetType() == ::phi::AllocationType::GPU);

    output->Resize(input.dims());
    // TODO(wilber): Use pinned memory.
    context.HostAlloc(output, input.dtype(), num_bytes);

    cudaError_t err = cudaMemcpyAsync(output->data(),
                                      input.data(),
                                      num_bytes,
                                      cudaMemcpyDeviceToHost,
                                      context.stream());
    CHECK(err == cudaSuccess)
        << "GpuMemCpy: D2H cudaMemcpyAsync failed: "
        << cudaGetErrorString(err);
    // TODO(wilber): Ir add sync op.
    err = cudaStreamSynchronize(context.stream());
    CHECK(err == cudaSuccess)
        << "GpuMemCpy: cudaStreamSynchronize failed: "
        << cudaGetErrorString(err);
  } else {
    // h2d
    CHECK(input.place().GetType() == ::phi::AllocationType::CPU ||
          input.place().GetType() == ::phi::AllocationType::GPUPINNED);

    output->Resize(input.dims());
    context.Alloc(output, input.dtype(), num_bytes, false);

    // TODO(wilber): Add sync op and stream.
    const cudaError_t err = cudaMemcpyAsync(output->data(),
                                            input.data(),
                                            num_bytes,
                                            cudaMemcpyHostToDevice,
                                            context.stream());
    CHECK(err == cudaSuccess)
        << "GpuMemCpy: H2D cudaMemcpyAsync failed: "
        << cudaGetErrorString(err);
  }
}
#endif

360
}  // namespace phi
361 362
}  // namespace kernel
}  // namespace infrt