/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/data_transform.h"

#include "glog/logging.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"

PHI_DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace experimental {

inline bool NeedTransformDataType(const DataType& input,
                                  const DataType& target,
                                  const TransformFlag& transform_flag) {
  return input != target &&
         (transform_flag.need_trans_data_type() ||
          target == DataType::COMPLEX64 || target == DataType::COMPLEX128);
}

inline bool NeedTransformLayout(const DataLayout& input,
                                const DataLayout& target,
                                const phi::Place& place,
                                const TransformFlag& transform_flag) {
  if (FLAGS_use_stride_kernel && target == DataLayout::STRIDED) {
    return false;
  }

  bool ret = transform_flag.need_trans_layout() &&
             (input != DataLayout::ALL_LAYOUT &&
              target != DataLayout::ALL_LAYOUT && input != target);
  if (place.GetType() == phi::AllocationType::GPU) {
    return false;
  }
  return ret;
}

inline bool NeedTransform2Contiguous(bool is_stride_kernel,
                                     bool is_contiguous) {
  return FLAGS_use_stride_kernel && !is_stride_kernel && !is_contiguous;
}

inline phi::DenseTensor TransDataLayout(const phi::DenseTensor& tensor,
                                        DataLayout layout) {
  auto& pool = phi::DeviceContextPool::Instance();
  VLOG(3) << "DataLayoutTransform src_layout: " << tensor.layout()
          << " dst_layout: " << layout;
  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return phi::TransferLayout(*dev_ctx, tensor, layout);
  } else {
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "Unsupported data layout cast from CPU to GPU."));
  }
  return tensor;
}
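// Illustrative use, as a sketch (assumes `x` is an initialized CPU
// DenseTensor; the name is hypothetical):
//   phi::DenseTensor nhwc = TransDataLayout(x, DataLayout::NHWC);
// Only CPU tensors are handled here; any other place triggers the
// PreconditionNotMet error above.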

template <typename Context>
phi::DenseTensor CastDataType(const Context& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BFLOAT16:
      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor CastDataType(const phi::GPUContext& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#endif

inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
                                      DataType dtype) {
  auto& pool = phi::DeviceContextPool::Instance();

  VLOG(3) << "DataTypeTransform src_dtype: " << tensor.dtype()
          << " dst_dtype: " << dtype;

  DefaultAllocator alloc(tensor.place());
  phi::DenseTensor out(&alloc, {dtype, tensor.dims(), tensor.layout()});

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
    phi::DenseTensor out;
    out.Resize(tensor.dims());
    auto* dev_ctx = static_cast<phi::CustomContext*>(pool.Get(tensor.place()));
    auto kernel_result =
        phi::KernelFactory::Instance().SelectKernelOrThrowError(
            "cast",
            {phi::TransToPhiBackend(tensor.place()),
             phi::DataLayout::ALL_LAYOUT,
             tensor.dtype()});
    using kernel_signature = void (*)(const phi::DeviceContext&,
                                      const phi::DenseTensor&,
                                      phi::DataType,
                                      phi::DenseTensor*);
    const auto& kernel = kernel_result.kernel;
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx, tensor, dtype, &out);
    return out;
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }
  return out;
}
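// Illustrative use, as a sketch (assumes `x` is an initialized FLOAT32
// DenseTensor; the name is hypothetical):
//   phi::DenseTensor x_fp64 = TransDataType(x, DataType::FLOAT64);
// The result keeps the place and layout of `x`; only the element type is
// converted via the CastDataType overloads above.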

inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
                                       Place dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << tensor.place()
          << " dst_place: " << dst_place;

  auto& pool = phi::DeviceContextPool::Instance();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(yy): TransDataPlace should wait for computation of input.
  if (tensor.place().GetType() != phi::AllocationType::GPUPINNED) {
    pool.Get(tensor.place())->Wait();
    pool.Get(dst_place)->Wait();
  }
#endif

  // FIXME(zcd): TransDataPlace is used to transform data from GPU to CPU and
  // the enforced checks have been done in GetDeviceContext, so the
  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
  // slow, especially when the number of elements is small, for example, a
  // learning rate tensor has a single element and lives on the CPU side.
  // One solution is to use a CUDA kernel to complete the copy when the
  // transform is from CPU to GPU and the number of elements is small.
  // The drawback, however, is that this solution makes training slower.
  phi::DenseTensor out;
  phi::DeviceContext* dev_ctx;
  if (dst_place.GetType() != AllocationType::CPU) {
    dev_ctx = pool.Get(dst_place);
  } else {
    dev_ctx = pool.Get(tensor.place());
  }
  phi::Copy(*dev_ctx, tensor, dst_place, true, &out);
  return out;
}
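// Illustrative use, as a sketch (assumes `x` is a GPU DenseTensor and CUDA is
// enabled; the name is hypothetical):
//   phi::DenseTensor host_x = TransDataPlace(x, phi::CPUPlace());
// The copy is blocking (`phi::Copy` is called with `blocking = true`), and the
// destination context is used whenever the target place is not the CPU.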

template <typename Context>
phi::DenseTensor TensorContiguous(const Context& dev_ctx,
                                  const phi::DenseTensor& tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);

  PD_VISIT_ALL_TYPES(tensor.dtype(), "TensorContiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "Trans2Contiguous...";

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::CPUContext>(*dev_ctx, tensor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::GPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_XPU
  } else if (tensor.place().GetType() == phi::AllocationType::XPU) {
    auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }

  return tensor;
}

void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
  if (!tensor->meta().is_contiguous()) {
    phi::DenseTensor tmp = Trans2Contiguous(*tensor);
    tensor->ShareDataWith(tmp);
  }
}
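// Illustrative use, as a sketch (assumes `t` is a non-contiguous DenseTensor
// produced by a stride kernel; the name is hypothetical):
//   phi::DenseTensor packed = Trans2Contiguous(t);  // returns a packed copy
//   CheckAndTrans2Contiguous(&t);  // or re-point `t` at a contiguous buffer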

phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
                               const phi::TensorArgDef& target_args_def,
                               const TransformFlag& transform_flag,
                               bool is_stride_kernel) {
  phi::DenseTensor out = tensor;
  bool trans_layout = false;
  bool trans_dtype = false;

  if (NeedTransform2Contiguous(is_stride_kernel, out.meta().is_contiguous())) {
    out = Trans2Contiguous(out);
  }

  if (NeedTransformLayout(tensor.layout(),
                          target_args_def.layout,
                          tensor.place(),
                          transform_flag) &&
      tensor.dims().size() != 1) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataLayout(out, target_args_def.layout);
    trans_layout = true;
  }

  if (NeedTransformDataType(
          tensor.dtype(), target_args_def.dtype, transform_flag)) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataType(out, target_args_def.dtype);
    trans_dtype = true;
  }

  if (NeedTransformPlace(
          out.place(), target_args_def.backend, transform_flag)) {
    out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
    if (!trans_layout && !trans_dtype &&
        tensor.place().GetType() == AllocationType::GPUPINNED) {
      // Sharing the buffer on the GPUPINNED place is a special case kept for
      // historical reasons: it should not be implemented this way, but the
      // performance of existing models depends on the inplace operation here
      // and would deteriorate if we reverted to a non-inplace implementation.
      // So the behavior is retained, which requires the `const_cast`.
      const_cast<phi::DenseTensor&>(tensor).ShareBufferWith(out);
    }
  }
  return out;
}
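// Illustrative use, as a sketch (`def` stands for a hypothetical TensorArgDef
// requesting the kernel's backend/layout/dtype; `x` is hypothetical too):
//   phi::DenseTensor ready =
//       TransformData(x, def, TransformFlag(), /*is_stride_kernel=*/false);
// Contiguity, layout, dtype, and place are handled in that order, so the
// device copy (if any) is performed last, on the already-converted tensor.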

std::shared_ptr<phi::DenseTensor> PrepareData(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    return std::make_shared<phi::DenseTensor>(std::move(out));
  }
  return nullptr;
}
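// Illustrative use, as a sketch of how a generated API function might call
// this (`kernel` and `x` are hypothetical here):
//   auto dense_x = PrepareData(
//       x, kernel.InputAt(0), TransformFlag(), /*is_stride_kernel=*/false);
// When no transform is needed, the original impl is returned unchanged, so the
// common path does not copy.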

paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (input) {
    return {*PrepareData(
        *input, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
  pt_tensors->reserve(inputs.size());

  for (const auto& input : inputs) {
    const auto& tensor_in = input.impl();
    auto dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in);
    if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
        (!NeedTransformPlace(
             tensor_in->place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             tensor_in->dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(tensor_in->layout(),
                              target_args_def.layout,
                              tensor_in->place(),
                              transform_flag) &&
         !(dense_tensor &&
           NeedTransform2Contiguous(is_stride_kernel,
                                    dense_tensor->meta().is_contiguous())))) {
      pt_tensors->emplace_back(
          *std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
    } else {
      pt_tensors->emplace_back(
          TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())),
                        target_args_def,
                        transform_flag,
                        is_stride_kernel));
    }
  }

  return pt_tensors;
}

paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (inputs) {
    return {*PrepareData(
        *inputs, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SelectedRows& selected_rows =
        *static_cast<phi::SelectedRows*>(tensor_in.get());
    if ((!transform_flag.NeedTransform() || !selected_rows.initialized() ||
         (!NeedTransformPlace(selected_rows.place(),
                              target_args_def.backend,
                              transform_flag))) &&
        !NeedTransform2Contiguous(
            false, selected_rows.value().meta().is_contiguous())) {
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    }

    if (selected_rows.place().GetType() == AllocationType::GPUPINNED) {
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        selected_rows.mutable_value()->ShareDataWith(dense_out);
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        selected_rows.mutable_value()->ShareBufferWith(dense_out);
      }
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    } else {
      auto out_new = std::make_shared<phi::SelectedRows>(
          selected_rows.rows(), selected_rows.height());
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        *out_new->mutable_value() = dense_out;
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        *out_new->mutable_value() = dense_out;
      }
      return out_new;
    }
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "selected_rows data transform now."));
}

paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  if (input) {
    return *PrepareDataForSelectedRows(*input, target_args_def, transform_flag);
  }
  return paddle::none;
}

std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCooTensor& sparse_tensor =
        *static_cast<phi::SparseCooTensor*>(tensor_in.get());
    if (sparse_tensor.indices().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
    }

    if (!sparse_tensor.indices().meta().is_contiguous()) {
      *sparse_tensor.mutable_indices() =
          Trans2Contiguous(sparse_tensor.indices());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "SparseCooTensor data transform now."));
}

paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCooTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCsrTensor& sparse_tensor =
        *static_cast<phi::SparseCsrTensor*>(tensor_in.get());
    if (sparse_tensor.crows().meta().is_contiguous() &&
        sparse_tensor.cols().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
    }

    if (!sparse_tensor.crows().meta().is_contiguous()) {
      *sparse_tensor.mutable_crows() = Trans2Contiguous(sparse_tensor.crows());
    }

    if (!sparse_tensor.cols().meta().is_contiguous()) {
      *sparse_tensor.mutable_cols() = Trans2Contiguous(sparse_tensor.cols());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "SparseCsrTensor data transform now."));
}

paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCsrTensor(*input);
  }
  return paddle::none;
}

std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (dense_tensor.meta().is_contiguous()) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }

    return std::make_shared<phi::DenseTensor>(
        std::move(Trans2Contiguous(dense_tensor)));
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of input tensor is nullptr, it doesn't support for "
      "DenseTensor data transform now."));
}

paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForDenseTensorInSparse(*input);
  }
  return paddle::none;
}
void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out) {
  if (tensor && tensor->initialized()) {
    *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
  }
}

void TransDataBackend(const std::vector<phi::DenseTensor*>& tensors,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> outs) {
  size_t n = tensors.size();
  for (size_t i = 0; i < n; ++i) {
    TransDataBackend(tensors[i], target_backend, outs[i]);
  }
}

void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out) {
  if (tensor) {
    TransDataBackend(&tensor->value(), target_backend, out->mutable_value());
  }
}
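// Illustrative use, as a sketch (`dense_out` is a hypothetical kernel output
// that has to end up on the CPU backend):
//   phi::DenseTensor host_out;
//   TransDataBackend(&dense_out, Backend::CPU, &host_out);
// A null or uninitialized input is skipped by the DenseTensor overload.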

/* ------------------ for auto parallel ----------------------- */

std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
    phi::DeviceContext* dev_ctx,
    const Tensor& tensor,
    const phi::distributed::TensorDistAttr& dist_attr) {
  auto tensor_in = tensor.impl();
  if (tensor_in) {
    phi::distributed::DistTensor* dist_tensor =
        static_cast<phi::distributed::DistTensor*>(tensor_in.get());
    if (dist_tensor->dist_attr() != dist_attr) {
      VLOG(6) << "Reshard tensor from " << dist_tensor->dist_attr() << " to "
              << dist_attr;
      auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
                                                                 dist_attr);
      return func->Eval(dev_ctx, *dist_tensor, dist_attr);
    }
    return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
  }
  return nullptr;
}
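// Illustrative use, as a sketch (`t` is assumed to wrap a DistTensor impl and
// `target_attr` is a hypothetical TensorDistAttr):
//   auto resharded = ReshardDistTensor(dev_ctx, t, target_attr);
// When the tensor's dist_attr already matches, the existing DistTensor is
// returned without any reshard.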

std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  return PrepareDataForDistTensor(
      std::static_pointer_cast<phi::distributed::DistTensor>(input.impl()),
      target_args_def,
      transform_flag,
      is_stride_kernel);
}

std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const std::shared_ptr<phi::distributed::DistTensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (input) {
    phi::distributed::DistTensor* dist_tensor = input.get();
    const phi::DenseTensor& dense_tensor = dist_tensor->value();
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return input;
    }
    // TODO(chenweihang): The global meta in DistTensor is not changed,
    // but the local meta in DenseTensor maybe changed, such as layout
    // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
    VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
    auto dist_out = std::make_shared<phi::distributed::DistTensor>(
        dist_tensor->dims(), dist_tensor->dist_attr());
    auto* out = dist_out->unsafe_mutable_value();
    *out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    return dist_out;
  }
  return nullptr;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel) {
  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
  for (auto x : input) {
    const auto& tensor_in = x.impl();
    if (tensor_in) {
      phi::distributed::DistTensor* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
      const phi::DenseTensor& dense_tensor = dist_tensor->value();
      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
          (!NeedTransformPlace(
               dense_tensor.place(), target_args_def.backend, transform_flag) &&
           !NeedTransformDataType(
               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
           !NeedTransformLayout(dense_tensor.layout(),
                                target_args_def.layout,
                                dense_tensor.place(),
                                transform_flag) &&
           !NeedTransform2Contiguous(is_stride_kernel,
                                     dense_tensor.meta().is_contiguous()))) {
        out.push_back(
            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
        continue;
      }
      phi::DenseTensor trans_in_tensor = TransformData(
          dense_tensor, target_args_def, transform_flag, is_stride_kernel);
      // TODO(GhostScreaming): The global meta in DistTensor is not changed,
      // but the local meta in DenseTensor maybe changed, such as layout
      // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
      VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
      out.push_back(std::make_shared<phi::distributed::DistTensor>(
          trans_in_tensor, dist_tensor->dist_attr()));
    } else {
      out.push_back(nullptr);
    }
  }
  return out;
}

}  // namespace experimental
}  // namespace paddle