data_transform.h 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/kernel_factory.h"
19
#include "paddle/phi/core/selected_rows.h"
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

namespace paddle {
namespace experimental {

// Bundle of switches describing which kinds of input transformation
// (data type, backend, layout) are permitted, plus a master "stop" switch
// that overrides all of them.
class TransformFlag {
 public:
  // All arguments are optional; the defaults permit backend and layout
  // transforms, forbid data-type transforms, and do not stop transformation.
  // NOTE(review): the constructor is deliberately left non-explicit so that
  // existing brace/implicit construction at call sites keeps compiling --
  // confirm against callers before tightening it.
  TransformFlag(bool stop_transform = false,
                bool trans_dtype = false,
                bool trans_backend = true,
                bool trans_layout = true)
      : stop_transform_(stop_transform),
        trans_data_type_(trans_dtype),
        trans_backend_(trans_backend),
        trans_layout_(trans_layout) {}

  // True when transformation is not stopped and at least one of the
  // data-type / backend / layout switches is enabled.
  bool NeedTransform() const {
    return !stop_transform_ &&
           (trans_data_type_ || trans_backend_ || trans_layout_);
  }

  // True when transformation is not stopped and data-type transform is on.
  bool need_trans_data_type() const {
    return !stop_transform_ && trans_data_type_;
  }

  // True when transformation is not stopped and backend transform is on.
  bool need_trans_backend() const { return !stop_transform_ && trans_backend_; }

  // True when transformation is not stopped and layout transform is on.
  bool need_trans_layout() const { return !stop_transform_ && trans_layout_; }

 private:
  // This flag has the highest priority: when it is true every query above
  // returns false. It can be set by api[data_transform->skip_transform] in
  // the yaml file.
  bool stop_transform_ = false;

  // trans_data_type_ can be set by api[data_transform->support_trans_dtype]
  // in the yaml file.
  // trans_data_type_ only affects the non-complex types;
  // complex types are always transferred, unless stop_transform_ is true.
  bool trans_data_type_ = false;

  // trans_backend_ and trans_layout_ are true by default,
  // and they can only be set by a global flag.
  bool trans_backend_ = true;
  bool trans_layout_ = true;
};

65
std::shared_ptr<phi::DenseTensor> PrepareData(
66
    const Tensor& input,
67
    const phi::TensorArgDef& target_args_def,
68 69
    const TransformFlag& transform_flag);

70
// PrepareData overload for an optional input Tensor.
// Takes an optional Tensor plus the target argument definition and the
// TransformFlag; returns an optional phi::DenseTensor.
// Only the declaration is visible in this header.
// NOTE(review): presumably an empty `input` yields an empty result --
// confirm against the definition.
paddle::optional<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

75
// PrepareData overload for a list of input Tensors.
// Takes a vector of Tensor plus the target argument definition and the
// TransformFlag; returns a uniquely-owned vector of phi::DenseTensor.
// Only the declaration is visible in this header.
// NOTE(review): presumably the result has one element per entry of `inputs`,
// in the same order -- confirm against the definition.
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

80 81 82 83 84
// PrepareData overload for an optional list of input Tensors.
// Takes an optional vector of Tensor plus the target argument definition and
// the TransformFlag; returns an optional vector of phi::DenseTensor.
// Only the declaration is visible in this header.
// NOTE(review): presumably an empty `inputs` yields an empty result --
// confirm against the definition.
paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
    const paddle::optional<std::vector<Tensor>>& inputs,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

85 86 87 88 89 90 91 92 93 94 95
// Only supports transferring the place for SelectedRows.
// Takes a single input Tensor plus the target argument definition and the
// TransformFlag; returns a shared pointer to a phi::SelectedRows.
// Only the declaration is visible in this header.
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
    const Tensor& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

// PrepareDataForSelectedRows overload for an optional input Tensor.
// Takes an optional Tensor plus the target argument definition and the
// TransformFlag; returns an optional phi::SelectedRows.
// Only the declaration is visible in this header.
// NOTE(review): presumably an empty `input` yields an empty result --
// confirm against the definition.
paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag);

96 97 98 99 100 101 102 103 104 105 106 107
// TransDataBackend overload for a single phi::DenseTensor.
// `tensor` is a pointer-to-const source, `target_backend` names the requested
// Backend, and `out` is a mutable destination pointer.
// Only the declaration is visible in this header.
// NOTE(review): presumably writes a copy of `tensor` placed on
// `target_backend` into `out`; null handling is not visible here -- confirm
// against the definition.
void TransDataBackend(const phi::DenseTensor* tensor,
                      Backend target_backend,
                      phi::DenseTensor* out);

// TransDataBackend overload for a list of phi::DenseTensor pointers.
// `tensor` is the list of source pointers, `target_backend` names the
// requested Backend, and `out` is the list of destination pointers.
// Only the declaration is visible in this header.
// NOTE(review): `out` is taken by value, so the vector of pointers is copied
// on each call; the pointed-to tensors are still shared with the caller.
// Presumably `tensor` and `out` are expected to have equal length -- confirm
// against the definition.
void TransDataBackend(const std::vector<phi::DenseTensor*>& tensor,
                      Backend target_backend,
                      std::vector<phi::DenseTensor*> out);

// TransDataBackend overload for a phi::SelectedRows.
// `tensor` is a pointer-to-const source, `target_backend` names the requested
// Backend, and `out` is a mutable destination pointer.
// Only the declaration is visible in this header.
// NOTE(review): presumably writes a copy of `tensor` placed on
// `target_backend` into `out` -- confirm against the definition.
void TransDataBackend(const phi::SelectedRows* tensor,
                      Backend target_backend,
                      phi::SelectedRows* out);

108 109
}  // namespace experimental
}  // namespace paddle