phi_utils.cc 9.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/phi_utils.h"

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h"

#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/type_defs.h"

namespace paddle {
namespace framework {

Z
Zeng Jinle 已提交
33 34 35 36 37 38 39 40 41 42 43
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
 public:
  explicit KernelArgsNameMakerByOpProto(
      const framework::proto::OpProto* op_proto)
      : op_proto_(op_proto) {
    PADDLE_ENFORCE_NOT_NULL(op_proto_, platform::errors::InvalidArgument(
                                           "Op proto cannot be nullptr."));
  }

  ~KernelArgsNameMakerByOpProto() {}

C
Chen Weihang 已提交
44 45 46
  const paddle::small_vector<const char*>& GetInputArgsNames() override;
  const paddle::small_vector<const char*>& GetOutputArgsNames() override;
  const paddle::small_vector<const char*>& GetAttrsArgsNames() override;
Z
Zeng Jinle 已提交
47

48
  phi::KernelSignature GetKernelSignature();
Z
Zeng Jinle 已提交
49 50 51 52 53 54 55

 private:
  DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);

 private:
  const framework::proto::OpProto* op_proto_;

C
Chen Weihang 已提交
56 57 58
  paddle::small_vector<const char*> input_names_;
  paddle::small_vector<const char*> output_names_;
  paddle::small_vector<const char*> attr_names_;
Z
Zeng Jinle 已提交
59 60
};

61
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
62
  proto::VarType::Type data_type =
63
      paddle::framework::TransToProtoVarType(kernel_key.dtype());
64
  // no need to set current device id here
65
  platform::Place place = phi::TransToPhiPlace(kernel_key.backend(), false);
66
  DataLayout data_layout = kernel_key.layout();
67
  LibraryType library_type = LibraryType::kPlain;
68
  if (kernel_key.backend() == phi::Backend::MKLDNN) {
69
    library_type = LibraryType::kMKLDNN;
70
  } else if (kernel_key.backend() == phi::Backend::GPUDNN) {
71
    library_type = LibraryType::kCUDNN;
72 73
  } else if (kernel_key.backend() == phi::Backend::KPS) {
    library_type = LibraryType::kKP;
74 75 76 77 78 79 80
  } else {
    // do nothing
  }
  // TODO(chenweihang): the customized_type_value is lost
  return OpKernelType(data_type, place, data_layout, library_type);
}

81
phi::KernelKey TransOpKernelTypeToPhiKernelKey(
82
    const OpKernelType& kernel_type) {
83
  phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
84
  if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
85
    backend = phi::Backend::MKLDNN;
86
  } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
87
    backend = phi::Backend::GPUDNN;
88 89
  } else if (kernel_type.library_type_ == LibraryType::kKP) {
    backend = phi::Backend::KPS;
90
  } else {
91
    // do nothing
92
  }
93
  paddle::experimental::DataLayout layout = kernel_type.data_layout_;
94
  paddle::experimental::DataType dtype =
95
      paddle::framework::TransToPhiDataType(kernel_type.data_type_);
96
  return phi::KernelKey(backend, layout, dtype);
97 98
}

99 100 101
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
                             const phi::KernelKey& kernel_key,
                             const framework::OperatorBase& op) {
102 103 104
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(expected_kernel_key.place_) ||
      paddle::platform::is_in_xpu_black_list(op.Type())) {
105
    VLOG(3) << "phi missing XPU kernel: " << op.Type()
106
            << ", expected_kernel_key:" << expected_kernel_key
107
            << ", fallbacking to CPU one!";
108 109
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
110 111 112 113
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(expected_kernel_key.place_)) {
114
    VLOG(3) << "phi missing NPU kernel: " << op.Type()
115
            << ", expected_kernel_key:" << expected_kernel_key
116
            << ", fallbacking to CPU one!";
117 118
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
119 120 121 122
  }
#endif
#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(expected_kernel_key.place_)) {
123
    VLOG(3) << "phi missing MLU kernel: " << op.Type()
124
            << ", expected_kernel_key:" << expected_kernel_key
125
            << ", fallbacking to CPU one!";
126 127 128 129 130 131
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
  }
#endif
#ifdef PADDLE_WITH_IPU
  if (platform::is_ipu_place(expected_kernel_key.place_)) {
132
    VLOG(3) << "phi missing IPU kernel: " << op.Type()
133
            << ", expected_kernel_key:" << expected_kernel_key
134 135 136 137 138 139 140 141 142
            << ", fallbacking to CPU one!";
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
  }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (platform::is_custom_place(expected_kernel_key.place_)) {
    VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
            << " kernel: " << op.Type()
143
            << ", expected_kernel_key:" << expected_kernel_key
144
            << ", fallbacking to CPU one!";
145 146
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
147 148
  }
#endif
149
  return phi::KernelKey();
150 151
}

C
Chen Weihang 已提交
152
const paddle::small_vector<const char*>&
153 154 155 156 157 158 159 160
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
  for (int i = 0; i < op_proto_->inputs_size(); ++i) {
    auto& in = op_proto_->inputs()[i];
    auto& in_name = in.name();
    if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
      continue;
    }
    // If contains dispensable input, we should override the
161
    // OpArgumentMapping method self in phi/ops/compat dir
162 163 164
    if (in.has_dispensable() && in.dispensable()) {
      continue;
    }
165 166 167 168 169 170 171 172
    input_names_.emplace_back(in_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel inputs: ";
    std::copy(input_names_.begin(), input_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
173 174 175 176
  }
  return input_names_;
}

C
Chen Weihang 已提交
177
const paddle::small_vector<const char*>&
178 179 180 181
KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
  for (int i = 0; i < op_proto_->outputs_size(); ++i) {
    auto& out = op_proto_->outputs()[i];
    auto& out_name = out.name();
182 183 184
    if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
      continue;
    }
185 186 187 188 189 190 191 192
    output_names_.emplace_back(out_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel outputs: ";
    std::copy(output_names_.begin(), output_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
193 194 195 196
  }
  return output_names_;
}

C
Chen Weihang 已提交
197
const paddle::small_vector<const char*>&
198 199 200 201
KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
  for (int i = 0; i < op_proto_->attrs_size(); ++i) {
    auto& attr = op_proto_->attrs()[i];
    auto& attr_name = attr.name();
202 203 204 205
    if (attr_name == "use_mkldnn" || attr_name == "use_cudnn" ||
        attr_name == "op_role" || attr_name == "op_role_var" ||
        attr_name == "op_namescope" || attr_name == "op_callstack" ||
        attr_name == "op_device") {
206 207 208 209 210 211
      continue;
    }
    if ((attr.has_extra() && attr.extra()) ||
        (attr.has_quant() && attr.quant())) {
      continue;
    }
212 213 214 215 216 217 218 219
    attr_names_.emplace_back(attr_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel attributes: ";
    std::copy(attr_names_.begin(), attr_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
220 221 222 223
  }
  return attr_names_;
}

224 225 226 227
phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
  return phi::KernelSignature(
      phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(),
      GetAttrsArgsNames(), GetOutputArgsNames());
228 229
}

230 231 232 233 234 235 236
std::once_flag kernel_sig_map_init_flag;

void InitDefaultKernelSignatureMap() {
  std::call_once(kernel_sig_map_init_flag, [] {
    for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
      const auto& op_type = pair.first;
      const auto* op_proto = pair.second.proto_;
237
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
238 239
          op_proto) {
        paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
240
        VLOG(10) << "Register `" << op_type << "` kernel signature:";
241
        phi::DefaultKernelSignatureMap::Instance().Insert(
242 243 244 245 246 247
            op_type, std::move(maker.GetKernelSignature()));
      }
    }
  });
}

248
static void SetAllocationForUninitializedDenseTensor(
249
    phi::DenseTensor* dense_tensor, const platform::Place& place) {
250 251 252 253 254 255 256 257 258
  int dtype_size = dense_tensor->dtype() == DataType::UNDEFINED
                       ? 0
                       : experimental::SizeOf(dense_tensor->dtype());
  int64_t numels = product(dense_tensor->dims());
  numels = numels < 0 ? 0 : numels;
  auto tmp_allocation_ptr = memory::Alloc(place, numels * dtype_size);
  auto& deleter = tmp_allocation_ptr.get_deleter();
  auto* allocation_ptr = tmp_allocation_ptr.release();
  auto shared_allocation =
259
      std::shared_ptr<phi::Allocation>(allocation_ptr, deleter);
260 261 262 263

  dense_tensor->ResetHolder(shared_allocation);
}

264 265
}  // namespace framework
}  // namespace paddle