/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/data_transform.h"

#include "glog/logging.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"
#include "paddle/utils/flags.h"

PD_DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace experimental {

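// A dtype transform is needed when the input and target dtypes differ and
// either the transform flag asks for it or the target is a complex type.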
inline bool NeedTransformDataType(const DataType& input,
                                  const DataType& target,
                                  const TransformFlag& transform_flag) {
  return input != target &&
         (transform_flag.need_trans_data_type() ||
          target == DataType::COMPLEX64 || target == DataType::COMPLEX128);
}

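// A layout transform is needed when both layouts are concrete (not
// ALL_LAYOUT) and differ; it is skipped for STRIDED targets and on GPU.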
inline bool NeedTransformLayout(const DataLayout& input,
                                const DataLayout& target,
                                const phi::Place& place,
                                const TransformFlag& transform_flag) {
  if (FLAGS_use_stride_kernel && target == DataLayout::STRIDED) {
    return false;
  }

  bool ret = transform_flag.need_trans_layout() &&
             (input != DataLayout::ALL_LAYOUT &&
              target != DataLayout::ALL_LAYOUT && input != target);
  if (place.GetType() == phi::AllocationType::GPU) {
    return false;
  }
  return ret;
}

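// When stride kernels are enabled globally, a kernel that cannot handle
// strides needs its input converted to a contiguous tensor first.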
inline bool NeedTransform2Contiguous(bool is_stride_kernel,
                                     bool is_contiguous) {
  return FLAGS_use_stride_kernel && !is_stride_kernel && !is_contiguous;
}

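// Transfers a tensor to the target layout; only implemented for CPU tensors.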
inline phi::DenseTensor TransDataLayout(const phi::DenseTensor& tensor,
                                        DataLayout layout) {
  auto& pool = phi::DeviceContextPool::Instance();
  VLOG(3) << "DataLayoutTransform src_layout: " << tensor.layout()
          << " dst_layout: " << layout;
  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return phi::TransferLayout(*dev_ctx, tensor, layout);
  } else {
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "Unsupported data layout cast from CPU to GPU."));
  }
  return tensor;
}

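// Dispatches phi::Cast on the runtime dtype of `tensor`.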
template <typename Context>
phi::DenseTensor CastDataType(const Context& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BFLOAT16:
      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor CastDataType(const phi::GPUContext& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#endif

inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
                                      DataType dtype) {
  auto& pool = phi::DeviceContextPool::Instance();

  VLOG(3) << "DataTypeTransform src_dtype: " << tensor.dtype()
          << " dst_dtype: " << dtype;

  DefaultAllocator alloc(tensor.place());
  phi::DenseTensor out(&alloc, {dtype, tensor.dims(), tensor.layout()});

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
    phi::DenseTensor out;
    out.Resize(tensor.dims());
    auto* dev_ctx = static_cast<phi::CustomContext*>(pool.Get(tensor.place()));
    auto kernel_result =
        phi::KernelFactory::Instance().SelectKernelOrThrowError(
            "cast",
            {phi::TransToPhiBackend(tensor.place()),
             phi::DataLayout::ALL_LAYOUT,
             tensor.dtype()});
    using kernel_signature = void (*)(const phi::DeviceContext&,
                                      const phi::DenseTensor&,
                                      phi::DataType,
                                      phi::DenseTensor*);
    const auto& kernel = kernel_result.kernel;
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx, tensor, dtype, &out);
    return out;
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }
  return out;
}

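// Copies a tensor to `dst_place`, synchronizing the involved device contexts
// first when CUDA/HIP is enabled.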
inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
                                       Place dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << tensor.place()
          << " dst_place: " << dst_place;

  auto& pool = phi::DeviceContextPool::Instance();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(yy): TransDataPlace should wait for computation of input.
  if (tensor.place().GetType() != phi::AllocationType::GPUPINNED) {
    pool.Get(tensor.place())->Wait();
    pool.Get(dst_place)->Wait();
  }
#endif

  // FIXME(zcd): TransDataPlace is used to transform data from GPU to CPU and
  // the enforced checkings have been done in GetDeviceContext, so the
  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
  // slow, especially when the number of elements is little, for example,
  // the elements of learning rate are one and it's CPU side.
  // One solution is to use a CUDA kernel to complete the copy operation when
  // the transforming is from CPU to GPU and the number of elements is little.
  // But the embarrassment is that this solution makes training slower.
  phi::DenseTensor out;
  phi::DeviceContext* dev_ctx;
  if (dst_place.GetType() != AllocationType::CPU) {
    dev_ctx = pool.Get(dst_place);
  } else {
    dev_ctx = pool.Get(tensor.place());
  }
  phi::Copy(*dev_ctx, tensor, dst_place, true, &out);
  return out;
}

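// Produces a contiguous copy of `tensor` by dispatching the Contiguous
// kernel on its dtype.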
template <typename Context>
phi::DenseTensor TensorContiguous(const Context& dev_ctx,
                                  const phi::DenseTensor& tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);

  PD_VISIT_ALL_TYPES(tensor.dtype(), "TensorContiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "Trans2Contiguous...";

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::CPUContext>(*dev_ctx, tensor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::GPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_XPU
  } else if (tensor.place().GetType() == phi::AllocationType::XPU) {
    auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when transforming data to contiguous."));
  }

  return tensor;
}

void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
  if (!tensor->meta().is_contiguous()) {
    phi::DenseTensor tmp = Trans2Contiguous(*tensor);
    tensor->ShareDataWith(tmp);
  }
}

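// Applies, in order, the contiguous, layout, dtype, and place transforms
// required to match the kernel's TensorArgDef.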
phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
                               const phi::TensorArgDef& target_args_def,
                               const TransformFlag& transform_flag,
                               bool is_stride_kernel) {
  phi::DenseTensor out = tensor;
  bool trans_layout = false;
  bool trans_dtype = false;

  if (NeedTransform2Contiguous(is_stride_kernel, out.meta().is_contiguous())) {
    out = Trans2Contiguous(out);
  }

  if (NeedTransformLayout(tensor.layout(),
                          target_args_def.layout,
                          tensor.place(),
                          transform_flag) &&
      tensor.dims().size() != 1) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataLayout(out, target_args_def.layout);
    trans_layout = true;
  }

  if (NeedTransformDataType(
          tensor.dtype(), target_args_def.dtype, transform_flag)) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataType(out, target_args_def.dtype);
    trans_dtype = true;
  }

  if (NeedTransformPlace(
          out.place(), target_args_def.backend, transform_flag)) {
    out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
    if (!trans_layout && !trans_dtype &&
        tensor.place().GetType() == AllocationType::GPUPINNED) {
      // Sharing buffer on GPUPINNED place is a special case due to historical
      // reasons, and it should not be implemented in this way from a
      // reasonable point of view, but because the performance of the previous
      // model depends on the inplace operation here, the model performance
      // will deteriorate after reverting to a non-inplace impl, so it needs
      // to be retained here and needs to use `const_cast`.
      const_cast<phi::DenseTensor&>(tensor).ShareBufferWith(out);
    }
  }
  return out;
}

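// Returns the input's DenseTensor unchanged when no transform is required;
// otherwise returns a transformed copy.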
std::shared_ptr<phi::DenseTensor> PrepareData(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    return std::make_shared<phi::DenseTensor>(std::move(out));
  }
  return nullptr;
}

paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (input) {
    return {*PrepareData(
        *input, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
  pt_tensors->reserve(inputs.size());

  for (const auto& input : inputs) {
    const auto& tensor_in = input.impl();
    auto dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in);
    if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
        (!NeedTransformPlace(
             tensor_in->place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             tensor_in->dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(tensor_in->layout(),
                              target_args_def.layout,
                              tensor_in->place(),
                              transform_flag) &&
         !(dense_tensor &&
           NeedTransform2Contiguous(is_stride_kernel,
                                    dense_tensor->meta().is_contiguous())))) {
      pt_tensors->emplace_back(
          *std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
    } else {
      pt_tensors->emplace_back(
          TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())),
                        target_args_def,
                        transform_flag,
                        is_stride_kernel));
    }
  }

  return pt_tensors;
}

paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (inputs) {
    return {*PrepareData(
        *inputs, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

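// SelectedRows only supports place and contiguous transforms; for GPUPINNED
// input the result is shared back into the original tensor, otherwise a new
// SelectedRows is returned.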
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SelectedRows& selected_rows =
        *static_cast<phi::SelectedRows*>(tensor_in.get());
    if ((!transform_flag.NeedTransform() || !selected_rows.initialized() ||
         (!NeedTransformPlace(selected_rows.place(),
                              target_args_def.backend,
                              transform_flag))) &&
        !NeedTransform2Contiguous(
            false, selected_rows.value().meta().is_contiguous())) {
422 423 424 425
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    }

    if (selected_rows.place().GetType() == AllocationType::GPUPINNED) {
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        selected_rows.mutable_value()->ShareDataWith(dense_out);
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        selected_rows.mutable_value()->ShareBufferWith(dense_out);
      }
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    } else {
      auto out_new = std::make_shared<phi::SelectedRows>(
          selected_rows.rows(), selected_rows.height());
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        *out_new->mutable_value() = dense_out;
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        *out_new->mutable_value() = dense_out;
      }
      return out_new;
    }
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "for selected_rows data transform."));
}

paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  if (input) {
    return *PrepareDataForSelectedRows(*input, target_args_def, transform_flag);
  }
  return paddle::none;
}

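// Sparse tensors only need their component DenseTensors (indices/values,
// crows/cols/values) to be made contiguous; no dtype, layout, or place
// transform is applied here.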
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCooTensor& sparse_tensor =
        *static_cast<phi::SparseCooTensor*>(tensor_in.get());
    if (sparse_tensor.indices().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
    }

    if (!sparse_tensor.indices().meta().is_contiguous()) {
      *sparse_tensor.mutable_indices() =
          Trans2Contiguous(sparse_tensor.indices());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "for SparseCooTensor data transform."));
}

paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCooTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCsrTensor& sparse_tensor =
        *static_cast<phi::SparseCsrTensor*>(tensor_in.get());
    if (sparse_tensor.crows().meta().is_contiguous() &&
        sparse_tensor.cols().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
    }

    if (!sparse_tensor.crows().meta().is_contiguous()) {
      *sparse_tensor.mutable_crows() = Trans2Contiguous(sparse_tensor.crows());
    }

    if (!sparse_tensor.cols().meta().is_contiguous()) {
      *sparse_tensor.mutable_cols() = Trans2Contiguous(sparse_tensor.cols());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "for SparseCsrTensor data transform."));
}

paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCsrTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (dense_tensor.meta().is_contiguous()) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }

    return std::make_shared<phi::DenseTensor>(
        std::move(Trans2Contiguous(dense_tensor)));
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "for DenseTensor data transform."));
}

paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForDenseTensorInSparse(*input);
  }
  return paddle::none;
}
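
// Copies the given tensor(s) to the place corresponding to `target_backend`.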
void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out) {
  if (tensor && tensor->initialized()) {
    *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
  }
}

void TransDataBackend(const std::vector<phi::DenseTensor*>& tensors,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> outs) {
  size_t n = tensors.size();
  for (size_t i = 0; i < n; ++i) {
    TransDataBackend(tensors[i], target_backend, outs[i]);
  }
}

void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out) {
  if (tensor) {
    TransDataBackend(&tensor->value(), target_backend, out->mutable_value());
  }
}

/* ------------------ for auto parallel ----------------------- */

std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::distributed::DistTensor* dist_tensor =
        static_cast<phi::distributed::DistTensor*>(tensor_in.get());
    const phi::DenseTensor& dense_tensor = dist_tensor->value();
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    // TODO(chenweihang): The global meta in DistTensor is not changed, but
    // the local meta in DenseTensor may be changed, such as a layout change
    // (NCHW->NHWC), so the new DistTensor's meta may not be unified.
    VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
    return std::make_shared<phi::distributed::DistTensor>(
        out, dist_tensor->dist_attr());
  }
  return nullptr;
}

}  // namespace experimental
}  // namespace paddle