/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
class DeviceContext;
namespace distributed {
class DistTensor;
class TensorDistAttr;
}  // namespace distributed
}  // namespace phi

namespace paddle {
namespace experimental {

// Controls which kinds of data transform (dtype / backend / layout) are
// allowed when preparing inputs for a kernel. `stop_transform_` has the
// highest priority: when set, every transform is disabled.
class TransformFlag {
 public:
  TransformFlag(bool stop_transform = false,
                bool trans_dtype = false,
                bool trans_backend = true,
                bool trans_layout = true)
      : stop_transform_(stop_transform),
        trans_data_type_(trans_dtype),
        trans_backend_(trans_backend),
        trans_layout_(trans_layout) {}

  // True when transforms are not globally stopped and at least one
  // transform kind is enabled.
  bool NeedTransform() const {
    return !stop_transform_ &&
           (trans_data_type_ || trans_backend_ || trans_layout_);
  }

  bool need_trans_data_type() const {
    return !stop_transform_ && trans_data_type_;
  }

  bool need_trans_backend() const { return !stop_transform_ && trans_backend_; }

  bool need_trans_layout() const { return !stop_transform_ && trans_layout_; }

 private:
  // This is the highest priority among the flags,
  // and can be set by api[data_transform->skip_transform] in the yaml file.
  bool stop_transform_ = false;

  // trans_data_type_ can be set by api[data_transform->support_trans_dtype]
  // in the yaml file.
  // trans_data_type_ only affects the non-complex types;
  // complex types are always transformed, unless stop_transform_ is true.
  bool trans_data_type_ = false;

  // trans_backend_ and trans_layout_ are true by default,
  // and they can only be set by a global flag.
  bool trans_backend_ = true;
  bool trans_layout_ = true;
};

// Returns a copy of `input_def` whose backend is resolved to a concrete one.
// In custom-device builds, a CUSTOM backend in the arg_def acts as a
// placeholder and is replaced with `kernel_backend` (taken from the
// expected kernel key); otherwise the arg_def is returned unchanged.
static inline phi::TensorArgDef GetKernelInputArgDef(
    const phi::TensorArgDef& input_def, phi::Backend kernel_backend) {
  phi::TensorArgDef resolved_def = input_def;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (resolved_def.backend == phi::Backend::CUSTOM) {
    resolved_def.SetBackend(kernel_backend);
  }
#endif
  return resolved_def;
}

std::shared_ptr<phi::DenseTensor> PrepareData(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

// Only supports transferring place for SelectedRows
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

// Only supports transferring to contiguous for SparseCooTensor
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const Tensor& input);

paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
    const paddle::optional<Tensor>& input);

// Only supports transferring to contiguous for SparseCsrTensor
std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const Tensor& input);

paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
    const paddle::optional<Tensor>& input);

// Only supports transferring to contiguous
std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const Tensor& input);

paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
    const paddle::optional<Tensor>& input);

void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out);

void TransDataBackend(const std::vector<phi::DenseTensor*>& tensor,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> out);

void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out);

phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor);

void CheckAndTrans2Contiguous(phi::DenseTensor* tensor);
// Decides whether a tensor on `src_place` must be moved so the kernel
// selected for `target` backend can consume it.
inline bool NeedTransformPlace(const phi::Place& src_place,
                               const Backend& target,
                               const TransformFlag& transform_flag) {
  // NOTE(dev): The default value of TransformFlag is True; if it is set to
  // False somewhere such as ops.yaml or backward.yaml, that means we should
  // skip data transform, because "stop_transform_" has the highest priority.
  if (!transform_flag.need_trans_backend()) {
    return false;
  }
  // GPU-pinned memory always requires a transfer; otherwise transform only
  // when a concrete target backend differs from the source tensor's backend
  // (GPUDNN is compared as plain GPU).
  bool ret = src_place.GetType() == AllocationType::GPUPINNED ||
             (target != Backend::ALL_BACKEND &&
              phi::TransToPhiBackend(src_place) !=
                  (target != Backend::GPUDNN ? target : Backend::GPU));
  return ret;
}

/* ------------------ for auto parallel ----------------------- */

std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
    phi::DeviceContext* dev_ctx,
    const Tensor& tensor,
    const phi::distributed::TensorDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const std::shared_ptr<phi::distributed::DistTensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel);

}  // namespace experimental
}  // namespace paddle