/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/data_transform.h"

#include "glog/logging.h"

#include "gflags/gflags.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace experimental {

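// Returns true when the input dtype differs from the dtype the kernel expects
// and either the transform flag allows dtype transforms or the target is a
// complex type (complex targets are always transformed).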
inline bool NeedTransformDataType(const DataType& input,
                                  const DataType& target,
                                  const TransformFlag& transform_flag) {
  return input != target &&
         (transform_flag.need_trans_data_type() ||
          target == DataType::COMPLEX64 || target == DataType::COMPLEX128);
}

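// Returns true when a layout transform is required. Strided targets are
// skipped when FLAGS_use_stride_kernel is on, and layout transforms are never
// performed for tensors on a GPU place.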
inline bool NeedTransformLayout(const DataLayout& input,
                                const DataLayout& target,
                                const phi::Place& place,
                                const TransformFlag& transform_flag) {
  if (FLAGS_use_stride_kernel && target == DataLayout::STRIDED) {
    return false;
  }

  bool ret = transform_flag.need_trans_layout() &&
             (input != DataLayout::ALL_LAYOUT &&
              target != DataLayout::ALL_LAYOUT && input != target);
  if (place.GetType() == phi::AllocationType::GPU) {
    return false;
  }
  return ret;
}

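// A tensor must be made contiguous when stride kernels are globally enabled,
// the selected kernel is not a stride kernel, and the tensor is not already
// contiguous.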
inline bool NeedTransform2Contiguous(bool is_stride_kernel,
                                     bool is_contiguous) {
  return FLAGS_use_stride_kernel && !is_stride_kernel && !is_contiguous;
}

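// Transforms the layout of a CPU DenseTensor via phi::TransferLayout; other
// places are rejected.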
inline phi::DenseTensor TransDataLayout(const phi::DenseTensor& tensor,
                                        DataLayout layout) {
  auto& pool = phi::DeviceContextPool::Instance();
  VLOG(3) << "DataLayoutTransform src_layout: " << tensor.layout()
          << " dst_layout: " << layout;
  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return phi::TransferLayout(*dev_ctx, tensor, layout);
  } else {
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "Unsupported data layout cast from CPU to GPU."));
  }
  return tensor;
}

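// Casts `tensor` to `dtype` by dispatching phi::Cast on the source dtype;
// unsupported source dtypes raise an Unimplemented error.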
template <typename Context>
phi::DenseTensor CastDataType(const Context& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BFLOAT16:
      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
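// Overload for phi::GPUContext. Unlike the generic overload above, bfloat16
// is not handled here.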
phi::DenseTensor CastDataType(const phi::GPUContext& dev_ctx,
                              const phi::DenseTensor& tensor,
                              DataType dtype) {
  switch (tensor.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#endif

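// Casts a DenseTensor to `dtype` on its current place (CPU, or GPU when
// compiled with CUDA/HIP support).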
inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
                                      DataType dtype) {
  auto& pool = phi::DeviceContextPool::Instance();

  VLOG(3) << "DataTypeTransform src_dtype: " << tensor.dtype()
          << " dst_dtype: " << dtype;

  DefaultAllocator alloc(tensor.place());
  phi::DenseTensor out(&alloc, {dtype, tensor.dims(), tensor.layout()});

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return CastDataType(*dev_ctx, tensor, dtype);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }
  return out;
}

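// Synchronously copies `tensor` to `dst_place`, using the destination device
// context unless the destination is CPU, in which case the source context is
// used.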
inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
                                       Place dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << tensor.place()
          << " dst_place: " << dst_place;

  auto& pool = phi::DeviceContextPool::Instance();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(yy): TransDataPlace should wait for the computation of its input.
  if (tensor.place().GetType() != phi::AllocationType::GPUPINNED) {
    pool.Get(tensor.place())->Wait();
    pool.Get(dst_place)->Wait();
  }
#endif

  // FIXME(zcd): TransDataPlace is used to transform data from GPU to CPU and
  // the enforced checks have already been done in GetDeviceContext, so the
  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` makes the program
  // slow, especially when the number of elements is small, for example when
  // the learning rate has a single element and lives on the CPU side.
  // One solution is to use a CUDA kernel to complete the copy when the
  // transform is from CPU to GPU and the number of elements is small, but the
  // embarrassment is that this solution makes training slower.
  phi::DenseTensor out;
  phi::DeviceContext* dev_ctx;
  if (dst_place.GetType() != AllocationType::CPU) {
    dev_ctx = pool.Get(dst_place);
  } else {
    dev_ctx = pool.Get(tensor.place());
  }
  phi::Copy(*dev_ctx, tensor, dst_place, true, &out);
  return out;
}

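// Returns a contiguous copy of `tensor`, produced by the ContiguousKernel
// instantiated for the tensor's dtype.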
template <typename Context>
phi::DenseTensor TensorContiguous(const Context& dev_ctx,
                                  const phi::DenseTensor& tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);

  PD_VISIT_ALL_TYPES(tensor.dtype(), "TensorContiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

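// Dispatches TensorContiguous on the tensor's place (CPU, GPU, or XPU).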
phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "Trans2Contiguous...";

  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::CPUContext>(*dev_ctx, tensor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::GPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_XPU
  } else if (tensor.place().GetType() == phi::AllocationType::XPU) {
    auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
    return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place type is not supported when transforming data to be "
        "contiguous."));
  }

  return tensor;
}

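// Makes `tensor` contiguous in place by sharing the data of a contiguous copy
// when it is not contiguous already.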
void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
  if (!tensor->meta().is_contiguous()) {
    phi::DenseTensor tmp = Trans2Contiguous(*tensor);
    tensor->ShareDataWith(tmp);
  }
}

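// Applies, in order, the transforms needed to feed one DenseTensor to the
// selected kernel: contiguous, layout, dtype, and finally place (backend).
// When stride kernels are enabled, the layout and dtype steps first force a
// contiguous tensor.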
phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
                               const phi::TensorArgDef& target_args_def,
                               const TransformFlag& transform_flag,
                               bool is_stride_kernel) {
  phi::DenseTensor out = tensor;
  bool trans_layout = false;
  bool trans_dtype = false;

  if (NeedTransform2Contiguous(is_stride_kernel, out.meta().is_contiguous())) {
    out = Trans2Contiguous(out);
  }

  if (NeedTransformLayout(tensor.layout(),
                          target_args_def.layout,
                          tensor.place(),
                          transform_flag) &&
      tensor.dims().size() != 1) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataLayout(out, target_args_def.layout);
    trans_layout = true;
  }

  if (NeedTransformDataType(
          tensor.dtype(), target_args_def.dtype, transform_flag)) {
    if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
      out = Trans2Contiguous(out);
    }
    out = TransDataType(out, target_args_def.dtype);
    trans_dtype = true;
  }

  if (NeedTransformPlace(
          out.place(), target_args_def.backend, transform_flag)) {
    out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
    if (!trans_layout && !trans_dtype &&
        tensor.place().GetType() == AllocationType::GPUPINNED) {
      // Sharing the buffer back to a GPUPINNED source is a special case kept
      // for historical reasons. It is not a reasonable design, but existing
      // models depend on the inplace behavior here and their performance
      // regresses without it, so it is retained and requires a `const_cast`.
      const_cast<phi::DenseTensor&>(tensor).ShareBufferWith(out);
    }
  }
  return out;
}

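// Converts an API-level Tensor into the DenseTensor the selected kernel
// expects, reusing the underlying tensor when no transform is needed.
// Illustrative call site (the real ones live in the generated API code and
// their exact arguments may differ):
//   auto dense_x =
//       PrepareData(x, kernel.InputAt(0), {}, /*is_stride_kernel=*/false);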
std::shared_ptr<phi::DenseTensor> PrepareData(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    return std::make_shared<phi::DenseTensor>(std::move(out));
  }
  return nullptr;
}

paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (input) {
    return {*PrepareData(
        *input, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
  pt_tensors->reserve(inputs.size());

  for (const auto& input : inputs) {
    const auto& tensor_in = input.impl();
    auto dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in);
    if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
        (!NeedTransformPlace(
             tensor_in->place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             tensor_in->dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(tensor_in->layout(),
                              target_args_def.layout,
                              tensor_in->place(),
                              transform_flag) &&
         !(dense_tensor &&
           NeedTransform2Contiguous(is_stride_kernel,
                                    dense_tensor->meta().is_contiguous())))) {
      pt_tensors->emplace_back(
          *std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
    } else {
      pt_tensors->emplace_back(
          TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())),
                        target_args_def,
                        transform_flag,
                        is_stride_kernel));
    }
  }

  return pt_tensors;
}

paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  if (inputs) {
    return {*PrepareData(
        *inputs, target_args_def, transform_flag, is_stride_kernel)};
  }
  return paddle::none;
}

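// SelectedRows variant of PrepareData: only place transforms and contiguity
// of the value tensor are handled. GPUPINNED inputs are updated in place;
// other places produce a new SelectedRows.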
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SelectedRows& selected_rows =
        *static_cast<phi::SelectedRows*>(tensor_in.get());
    if ((!transform_flag.NeedTransform() || !selected_rows.initialized() ||
         (!NeedTransformPlace(selected_rows.place(),
                              target_args_def.backend,
                              transform_flag))) &&
        !NeedTransform2Contiguous(
            false, selected_rows.value().meta().is_contiguous())) {
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    }

    if (selected_rows.place().GetType() == AllocationType::GPUPINNED) {
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        selected_rows.mutable_value()->ShareDataWith(dense_out);
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        selected_rows.mutable_value()->ShareBufferWith(dense_out);
      }
      return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
    } else {
      auto out_new = std::make_shared<phi::SelectedRows>(
          selected_rows.rows(), selected_rows.height());
      if (NeedTransform2Contiguous(
              false, selected_rows.value().meta().is_contiguous())) {
        auto dense_out = Trans2Contiguous(selected_rows.value());
        *out_new->mutable_value() = dense_out;
      }
      if (transform_flag.NeedTransform() && selected_rows.initialized() &&
          NeedTransformPlace(
              selected_rows.place(), target_args_def.backend, transform_flag)) {
        auto dense_out =
            TransDataPlace(selected_rows.value(),
                           phi::TransToPhiPlace(target_args_def.backend));
        *out_new->mutable_value() = dense_out;
      }
      return out_new;
    }
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "by selected_rows data transform."));
}

paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  if (input) {
    return *PrepareDataForSelectedRows(*input, target_args_def, transform_flag);
  }
  return paddle::none;
}

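// Ensures the indices and values of a SparseCooTensor are contiguous,
// replacing them in place when they are not.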
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCooTensor& sparse_tensor =
        *static_cast<phi::SparseCooTensor*>(tensor_in.get());
    if (sparse_tensor.indices().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
    }

    if (!sparse_tensor.indices().meta().is_contiguous()) {
      *sparse_tensor.mutable_indices() =
          Trans2Contiguous(sparse_tensor.indices());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "by SparseCooTensor data transform."));
}

paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCooTensor(*input);
  }
  return paddle::none;
}

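// Same as above for SparseCsrTensor: crows, cols, and values are made
// contiguous in place when needed.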
std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::SparseCsrTensor& sparse_tensor =
        *static_cast<phi::SparseCsrTensor*>(tensor_in.get());
    if (sparse_tensor.crows().meta().is_contiguous() &&
        sparse_tensor.cols().meta().is_contiguous() &&
        sparse_tensor.values().meta().is_contiguous()) {
      return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
    }

    if (!sparse_tensor.crows().meta().is_contiguous()) {
      *sparse_tensor.mutable_crows() = Trans2Contiguous(sparse_tensor.crows());
    }

    if (!sparse_tensor.cols().meta().is_contiguous()) {
      *sparse_tensor.mutable_cols() = Trans2Contiguous(sparse_tensor.cols());
    }

    if (!sparse_tensor.values().meta().is_contiguous()) {
      *sparse_tensor.mutable_values() =
          Trans2Contiguous(sparse_tensor.values());
    }
    return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "by SparseCsrTensor data transform."));
}

paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForSparseCsrTensor(*input);
  }
  return paddle::none;
}

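// Returns the input DenseTensor unchanged when it is contiguous, otherwise a
// contiguous copy.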
std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const Tensor& input) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::DenseTensor& dense_tensor =
        *static_cast<phi::DenseTensor*>(tensor_in.get());
    if (dense_tensor.meta().is_contiguous()) {
      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
    }

    return std::make_shared<phi::DenseTensor>(
        std::move(Trans2Contiguous(dense_tensor)));
  }
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The impl() of the input tensor is nullptr, which is not yet supported "
      "by DenseTensor data transform."));
}

paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const paddle::optional<Tensor>& input) {
  if (input) {
    return *PrepareDataForDenseTensorInSparse(*input);
  }
  return paddle::none;
}
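
// Copies initialized tensors to the place corresponding to `target_backend`;
// overloads cover a single DenseTensor, a list of DenseTensors, and
// SelectedRows.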
void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out) {
  if (tensor && tensor->initialized()) {
    *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
  }
}

void TransDataBackend(const std::vector<phi::DenseTensor*>& tensors,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> outs) {
  size_t n = tensors.size();
  for (size_t i = 0; i < n; ++i) {
    TransDataBackend(tensors[i], target_backend, outs[i]);
  }
}

void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out) {
  if (tensor) {
    TransDataBackend(&tensor->value(), target_backend, out->mutable_value());
  }
}


#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */

// DistTensor variant of PrepareData: transforms the local DenseTensor value
// and rebuilds a DistTensor with the original dist_attr.
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  const auto& tensor_in = input.impl();
  if (tensor_in) {
    phi::distributed::DistTensor* dist_tensor =
        static_cast<phi::distributed::DistTensor*>(tensor_in.get());
    const phi::DenseTensor& dense_tensor = dist_tensor->value();
    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
        (!NeedTransformPlace(
             dense_tensor.place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(dense_tensor.layout(),
                              target_args_def.layout,
                              dense_tensor.place(),
                              transform_flag) &&
         !NeedTransform2Contiguous(is_stride_kernel,
                                   dense_tensor.meta().is_contiguous()))) {
      return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
    }
    phi::DenseTensor out = TransformData(
        dense_tensor, target_args_def, transform_flag, is_stride_kernel);
    // TODO(chenweihang): The global meta in DistTensor is not changed, but the
    // local meta in the DenseTensor may be changed, e.g. by a layout change
    // (NCHW -> NHWC), so the new DistTensor's meta may not be unified.
    VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
    return std::make_shared<phi::distributed::DistTensor>(
        out, dist_tensor->dist_attr());
  }
  return nullptr;
}
#endif

}  // namespace experimental
}  // namespace paddle