/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/data_transform.h"

#include "glog/logging.h"

#include "gflags/gflags.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"

DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace experimental {

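// Whether the input dtype must be cast to the kernel's target dtype. A cast is
// needed when the dtypes differ and either the transform flag enables dtype
// transforms or the target is a complex type (COMPLEX64/COMPLEX128), which
// forces the cast regardless of the flag.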
inline bool NeedTransformDataType(const DataType& input,
                                  const DataType& target,
                                  const TransformFlag& transform_flag) {
  return input != target &&
         (transform_flag.need_trans_data_type() ||
          target == DataType::COMPLEX64 || target == DataType::COMPLEX128);
}

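// Whether the input layout must be transformed to the kernel's target layout.
// Strided kernels (FLAGS_use_stride_kernel with a STRIDED target) and GPU
// tensors never get a layout transform; otherwise one is needed when both
// layouts are concrete (not ALL_LAYOUT) and differ.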
inline bool NeedTransformLayout(const DataLayout& input,
                                const DataLayout& target,
                                const phi::Place& place,
                                const TransformFlag& transform_flag) {
  if (FLAGS_use_stride_kernel && target == DataLayout::STRIDED) {
    return false;
  }

  bool ret = transform_flag.need_trans_layout() &&
             (input != DataLayout::ALL_LAYOUT &&
              target != DataLayout::ALL_LAYOUT && input != target);
  if (place.GetType() == phi::AllocationType::GPU) {
    return false;
  }
  return ret;
}

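// A tensor only needs to be made contiguous when stride kernels are enabled
// globally, the selected kernel is not itself a stride kernel, and the tensor
// is not already contiguous.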
inline bool NeedTransform2Contiguous(bool is_stride_kernel,
                                     bool is_contiguous) {
  return FLAGS_use_stride_kernel && !is_stride_kernel && !is_contiguous;
}

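// Transforms the layout of a CPU DenseTensor (e.g. between NCHW and NHWC) via
// phi::TransferLayout; other places are rejected.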
inline phi::DenseTensor TransDataLayout(const phi::DenseTensor& tensor,
                                        DataLayout layout) {
  auto& pool = phi::DeviceContextPool::Instance();
  VLOG(3) << "DataLayoutTransform src_layout: " << tensor.layout()
          << " dst_layout: " << layout;
  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return phi::TransferLayout(*dev_ctx, tensor, layout);
  } else {
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "Unsupported data layout cast from CPU to GPU."));
  }
  return tensor;
}

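// Casts a DenseTensor to `dtype` by dispatching phi::Cast on the source dtype.
// The GPU overload below mirrors this switch, minus BFLOAT16.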
template <typename Context>
phi::DenseTensor CastDataType(const Context& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BFLOAT16:
      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor CastDataType(const phi::GPUContext& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#endif

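// Dtype transform entry: picks the CPU or GPU context matching the tensor's
// place and returns the casted tensor. Usage sketch (assuming `t` is an
// initialized CPU DenseTensor):
//   phi::DenseTensor t_fp64 = TransDataType(t, DataType::FLOAT64);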
inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
                                      DataType dtype) {
  auto& pool = phi::DeviceContextPool::Instance();

  VLOG(3) << "DataTypeTransform src_dtype: " << tensor.dtype()
          << " dst_dtype: " << dtype;

  DefaultAllocator alloc(tensor.place());
  phi::DenseTensor out(&alloc, {dtype, tensor.dims(), tensor.layout()});

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }
  return out;
}

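// Copies a DenseTensor to `dst_place` with a blocking phi::Copy, first waiting
// on the source and destination contexts (except for GPU-pinned memory). The
// destination context is used unless the destination is CPU, in which case the
// source context drives the copy.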
inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
                                       Place dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << tensor.place()
          << " dst_place: " << dst_place;

  auto& pool = phi::DeviceContextPool::Instance();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(yy): TransDataPlace should wait for computation of input.
  if (tensor.place().GetType() != phi::AllocationType::GPUPINNED) {
    pool.Get(tensor.place())->Wait();
    pool.Get(dst_place)->Wait();
  }
#endif

  // FIXME(zcd): TransDataPlace is used to transform data from GPU to CPU and
  // the enforced checkings have been done in GetDeviceContext, so the
  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
  // slow, especially when the number of elements is small, for example,
  // the learning rate is a single element and it lives on the CPU side.
  // One solution is to use a CUDA kernel to complete the copy operation when
  // the transform is from CPU to GPU and the number of elements is small.
  // But the embarrassment is that this solution makes training slower.
  phi::DenseTensor out;
  phi::DeviceContext* dev_ctx;
  if (dst_place.GetType() != AllocationType::CPU) {
    dev_ctx = pool.Get(dst_place);
  } else {
    dev_ctx = pool.Get(tensor.place());
  }
  phi::Copy(*dev_ctx, tensor, dst_place, true, &out);
  return out;
}

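// Returns a contiguous copy of `tensor` by dispatching phi::ContiguousKernel
// over all supported dtypes via PD_VISIT_ALL_TYPES.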
template <typename Context>
phi::DenseTensor TensorContiguous(const Context& dev_ctx,
                                  const phi::DenseTensor& tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);

  PD_VISIT_ALL_TYPES(tensor.dtype(), "TensorContiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

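// Dispatches TensorContiguous on the CPU/GPU/XPU context matching the tensor's
// place; other places raise Unimplemented.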
phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "Trans2Contiguous...";

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::CPUContext>(*dev_ctx, tensor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::GPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_XPU
  } else if (tensor.place().GetType() == phi::AllocationType::XPU) {
    auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when making tensor contiguous."));
  }

  return tensor;
}

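// In-place variant: if `tensor` is not contiguous, its storage is replaced
// with a contiguous copy via ShareDataWith.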
void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
  if (!tensor->meta().is_contiguous()) {
    phi::DenseTensor tmp = Trans2Contiguous(*tensor);
    tensor->ShareDataWith(tmp);
  }
}

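// Applies, in order, the contiguous, layout, dtype and place transforms that
// the target kernel requires. When only a place transform happens and the
// source is GPU-pinned memory, the transformed buffer is shared back into the
// original tensor via ShareBufferWith.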
phi::DenseTensor TransformData(phi::DenseTensor* tensor,
                               const phi::TensorArgDef& target_args_def,
                               const TransformFlag& transform_flag,
                               bool is_stride_kernel) {
  phi::DenseTensor out = *tensor;
  bool trans_layout = false;
  bool trans_dtype = false;

  if (NeedTransform2Contiguous(is_stride_kernel, out.meta().is_contiguous())) {
    out = Trans2Contiguous(out);
  }

  if (NeedTransformLayout(tensor->layout(),
                          target_args_def.layout,
                          tensor->place(),
                          transform_flag) &&
      tensor->dims().size() != 1) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataLayout(out, target_args_def.layout);
    trans_layout = true;
  }

  if (NeedTransformDataType(
          tensor->dtype(), target_args_def.dtype, transform_flag)) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataType(out, target_args_def.dtype);
    trans_dtype = true;
  }

  if (NeedTransformPlace(
          out.place(), target_args_def.backend, transform_flag)) {
    out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
    if (!trans_layout && !trans_dtype &&
        tensor->place().GetType() == AllocationType::GPUPINNED) {
      tensor->ShareBufferWith(out);
    }
  }
  return out;
}

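// Prepares a DenseTensor input for a kernel: if no transform is required, the
// underlying tensor is reused as-is; otherwise TransformData builds a new
// DenseTensor matching target_args_def. Usage sketch (with a hypothetical
// Tensor `x`, TensorArgDef `args_def` and TransformFlag `flag`):
//   auto dense_x = PrepareData(x, args_def, flag, /*is_stride_kernel=*/false);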
std::shared_ptr<phi::DenseTensor> PrepareData(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        &dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    return std::make_shared<phi::DenseTensor>(std::move(out));
  }
  return nullptr;
}

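// The optional overloads below simply forward to the corresponding overloads
// above and return paddle::none when the input is absent.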
paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (input) {
    return {*PrepareData(
        *input, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

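// Vector variant: each input is either reused directly or transformed, and the
// results are collected into a freshly allocated vector of DenseTensors.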
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
  pt_tensors->reserve(inputs.size());

  for (const auto& input : inputs) {
    const auto& tensor_in = input.impl();
    auto dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in);
    if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
        (!NeedTransformPlace(
             tensor_in->place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             tensor_in->dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(tensor_in->layout(),
                              target_args_def.layout,
                              tensor_in->place(),
                              transform_flag) &&
         !(dense_tensor &&
           NeedTransform2Contiguous(is_stride_kernel,
                                    dense_tensor->meta().is_contiguous())))) {
      pt_tensors->emplace_back(
          *std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
    } else {
      pt_tensors->emplace_back(
          TransformData(static_cast<phi::DenseTensor*>(tensor_in.get()),
                        target_args_def,
                        transform_flag,
                        is_stride_kernel));
    }
  }

  return pt_tensors;
}

paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (inputs) {
    return {*PrepareData(
        *inputs, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

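// SelectedRows inputs only support place transforms (plus making the value
// tensor contiguous). GPU-pinned inputs are updated in place by sharing the
// transformed value back into the original SelectedRows; other inputs get a
// new SelectedRows holding the transformed value.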
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SelectedRows& selected_rows =
        *static_cast<phi::SelectedRows*>(tensor_in.get());
    if ((!transform_flag.NeedTransform() || !selected_rows.initialized() ||
         (!NeedTransformPlace(selected_rows.place(),
                              target_args_def.backend,
                              transform_flag))) &&
        !NeedTransform2Contiguous(
            false, selected_rows.value().meta().is_contiguous())) {
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    }

    if (selected_rows.place().GetType() == AllocationType::GPUPINNED) {
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        selected_rows.mutable_value()->ShareDataWith(dense_out);
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        selected_rows.mutable_value()->ShareBufferWith(dense_out);
      }
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    } else {
      auto out_new = std::make_shared<phi::SelectedRows>(
          selected_rows.rows(), selected_rows.height());
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        *out_new->mutable_value() = dense_out;
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        *out_new->mutable_value() = dense_out;
      }
      return out_new;
    }
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "selected_rows data transform now."));
}

paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  if (input) {
    return *PrepareDataForSelectedRows(*input, target_args_def, transform_flag);
  }
  return paddle::none;
}

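// The sparse helpers below only ensure that the component tensors (indices and
// values for COO; crows, cols and values for CSR) are contiguous; no dtype,
// layout or place transform is applied.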
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCooTensor& sparse_tensor =
        *static_cast<phi::SparseCooTensor*>(tensor_in.get());
    if (sparse_tensor.indices().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
    }

    if (!sparse_tensor.indices().meta().is_contiguous()) {
      *sparse_tensor.mutable_indices() =
          Trans2Contiguous(sparse_tensor.indices());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "SparseCooTensor data transform now."));
}

paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCooTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCsrTensor& sparse_tensor =
        *static_cast<phi::SparseCsrTensor*>(tensor_in.get());
    if (sparse_tensor.crows().meta().is_contiguous() &&
        sparse_tensor.cols().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
    }

    if (!sparse_tensor.crows().meta().is_contiguous()) {
      *sparse_tensor.mutable_crows() = Trans2Contiguous(sparse_tensor.crows());
    }

    if (!sparse_tensor.cols().meta().is_contiguous()) {
      *sparse_tensor.mutable_cols() = Trans2Contiguous(sparse_tensor.cols());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "SparseCsrTensor data transform now."));
}

paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCsrTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (dense_tensor.meta().is_contiguous()) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }

    return std::make_shared<phi::DenseTensor>(
        std::move(Trans2Contiguous(dense_tensor)));
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "DenseTensor data transform now."));
}

paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForDenseTensorInSparse(*input);
  }
  return paddle::none;
}
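
// Copies an initialized DenseTensor output to the place of `target_backend`.
// The vector overload applies this element-wise and the SelectedRows overload
// transforms its value tensor.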
void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out) {
  if (tensor && tensor->initialized()) {
    *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
  }
}

void TransDataBackend(const std::vector<phi::DenseTensor*>& tensors,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> outs) {
  size_t n = tensors.size();
  for (size_t i = 0; i < n; ++i) {
    TransDataBackend(tensors[i], target_backend, outs[i]);
  }
}

void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out) {
  if (tensor) {
    TransDataBackend(&tensor->value(), target_backend, out->mutable_value());
  }
}

}  // namespace experimental
}  // namespace paddle