/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/phi_utils.h"

#include <iterator>
#include <sstream>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/type_defs.h"

namespace paddle {
namespace framework {

Z
Zeng Jinle 已提交
33 34 35 36 37 38 39 40 41 42 43
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
 public:
  explicit KernelArgsNameMakerByOpProto(
      const framework::proto::OpProto* op_proto)
      : op_proto_(op_proto) {
    PADDLE_ENFORCE_NOT_NULL(op_proto_, platform::errors::InvalidArgument(
                                           "Op proto cannot be nullptr."));
  }

  ~KernelArgsNameMakerByOpProto() {}

C
Chen Weihang 已提交
44 45 46
  const paddle::small_vector<const char*>& GetInputArgsNames() override;
  const paddle::small_vector<const char*>& GetOutputArgsNames() override;
  const paddle::small_vector<const char*>& GetAttrsArgsNames() override;
Z
Zeng Jinle 已提交
47

48
  phi::KernelSignature GetKernelSignature();
Z
Zeng Jinle 已提交
49 50 51 52 53 54 55

 private:
  DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);

 private:
  const framework::proto::OpProto* op_proto_;

C
Chen Weihang 已提交
56 57 58
  paddle::small_vector<const char*> input_names_;
  paddle::small_vector<const char*> output_names_;
  paddle::small_vector<const char*> attr_names_;
Z
Zeng Jinle 已提交
59 60
};

61
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
62
  proto::VarType::Type data_type =
63
      paddle::framework::TransToProtoVarType(kernel_key.dtype());
64
  // no need to set current device id here
65
  platform::Place place = phi::TransToPhiPlace(kernel_key.backend(), false);
66
  DataLayout data_layout = kernel_key.layout();
67
  LibraryType library_type = LibraryType::kPlain;
68
  if (kernel_key.backend() == phi::Backend::MKLDNN) {
69
    library_type = LibraryType::kMKLDNN;
70
  } else if (kernel_key.backend() == phi::Backend::GPUDNN) {
71
    library_type = LibraryType::kCUDNN;
72 73
  } else if (kernel_key.backend() == phi::Backend::KPS) {
    library_type = LibraryType::kKP;
74 75 76 77 78 79 80
  } else {
    // do nothing
  }
  // TODO(chenweihang): the customized_type_value is lost
  return OpKernelType(data_type, place, data_layout, library_type);
}

81
phi::KernelKey TransOpKernelTypeToPhiKernelKey(
82
    const OpKernelType& kernel_type) {
83
  phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
84 85 86 87 88 89 90 91 92 93 94 95
  switch (kernel_type.library_type_) {
    case LibraryType::kCUDNN:
      backend = phi::Backend::GPUDNN;
      break;
    case LibraryType::kMKLDNN:
      backend = phi::Backend::MKLDNN;
      break;
    case LibraryType::kKP:
      backend = phi::Backend::KPS;
      break;
    default:
      break;
96
  }
97 98
  return phi::KernelKey(backend, kernel_type.data_layout_,
                        framework::TransToPhiDataType(kernel_type.data_type_));
99 100
}

101 102 103
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
                             const phi::KernelKey& kernel_key,
                             const framework::OperatorBase& op) {
104 105 106
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(expected_kernel_key.place_) ||
      paddle::platform::is_in_xpu_black_list(op.Type())) {
107
    VLOG(3) << "phi missing XPU kernel: " << op.Type()
108
            << ", expected_kernel_key:" << expected_kernel_key
109
            << ", fallbacking to CPU one!";
110 111
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
112 113 114 115
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(expected_kernel_key.place_)) {
116
    VLOG(3) << "phi missing NPU kernel: " << op.Type()
117
            << ", expected_kernel_key:" << expected_kernel_key
118
            << ", fallbacking to CPU one!";
119 120
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
121 122 123 124
  }
#endif
#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(expected_kernel_key.place_)) {
125
    VLOG(3) << "phi missing MLU kernel: " << op.Type()
126
            << ", expected_kernel_key:" << expected_kernel_key
127
            << ", fallbacking to CPU one!";
128 129 130 131 132 133
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
  }
#endif
#ifdef PADDLE_WITH_IPU
  if (platform::is_ipu_place(expected_kernel_key.place_)) {
134
    VLOG(3) << "phi missing IPU kernel: " << op.Type()
135
            << ", expected_kernel_key:" << expected_kernel_key
136 137 138 139 140 141 142 143 144
            << ", fallbacking to CPU one!";
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
  }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (platform::is_custom_place(expected_kernel_key.place_)) {
    VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
            << " kernel: " << op.Type()
145
            << ", expected_kernel_key:" << expected_kernel_key
146
            << ", fallbacking to CPU one!";
147 148
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
149 150
  }
#endif
151
  return phi::KernelKey();
152 153
}

C
Chen Weihang 已提交
154
const paddle::small_vector<const char*>&
155 156 157 158 159 160 161 162
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
  for (int i = 0; i < op_proto_->inputs_size(); ++i) {
    auto& in = op_proto_->inputs()[i];
    auto& in_name = in.name();
    if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
      continue;
    }
    // If contains dispensable input, we should override the
163
    // OpArgumentMapping method self in phi/ops/compat dir
164 165 166
    if (in.has_dispensable() && in.dispensable()) {
      continue;
    }
167 168 169 170 171 172 173 174
    input_names_.emplace_back(in_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel inputs: ";
    std::copy(input_names_.begin(), input_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
175 176 177 178
  }
  return input_names_;
}

C
Chen Weihang 已提交
179
const paddle::small_vector<const char*>&
180 181 182 183
KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
  for (int i = 0; i < op_proto_->outputs_size(); ++i) {
    auto& out = op_proto_->outputs()[i];
    auto& out_name = out.name();
184 185 186
    if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
      continue;
    }
187 188 189 190 191 192 193 194
    output_names_.emplace_back(out_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel outputs: ";
    std::copy(output_names_.begin(), output_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
195 196 197 198
  }
  return output_names_;
}

C
Chen Weihang 已提交
199
const paddle::small_vector<const char*>&
200 201 202 203
KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
  for (int i = 0; i < op_proto_->attrs_size(); ++i) {
    auto& attr = op_proto_->attrs()[i];
    auto& attr_name = attr.name();
204 205 206 207
    if (attr_name == "use_mkldnn" || attr_name == "use_cudnn" ||
        attr_name == "op_role" || attr_name == "op_role_var" ||
        attr_name == "op_namescope" || attr_name == "op_callstack" ||
        attr_name == "op_device") {
208 209 210 211 212 213
      continue;
    }
    if ((attr.has_extra() && attr.extra()) ||
        (attr.has_quant() && attr.quant())) {
      continue;
    }
214 215 216 217 218 219 220 221
    attr_names_.emplace_back(attr_name.c_str());
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "PhiKernel attributes: ";
    std::copy(attr_names_.begin(), attr_names_.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    VLOG(10) << sout.str();
222 223 224 225
  }
  return attr_names_;
}

226 227 228 229
phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
  return phi::KernelSignature(
      phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(),
      GetAttrsArgsNames(), GetOutputArgsNames());
230 231
}

232 233 234 235 236 237 238
std::once_flag kernel_sig_map_init_flag;

void InitDefaultKernelSignatureMap() {
  std::call_once(kernel_sig_map_init_flag, [] {
    for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
      const auto& op_type = pair.first;
      const auto* op_proto = pair.second.proto_;
239
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
240 241
          op_proto) {
        paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
242
        VLOG(10) << "Register `" << op_type << "` kernel signature:";
243
        phi::DefaultKernelSignatureMap::Instance().Insert(
244 245 246 247 248 249
            op_type, std::move(maker.GetKernelSignature()));
      }
    }
  });
}

250
static void SetAllocationForUninitializedDenseTensor(
251
    phi::DenseTensor* dense_tensor, const platform::Place& place) {
252 253 254 255 256 257 258 259 260
  int dtype_size = dense_tensor->dtype() == DataType::UNDEFINED
                       ? 0
                       : experimental::SizeOf(dense_tensor->dtype());
  int64_t numels = product(dense_tensor->dims());
  numels = numels < 0 ? 0 : numels;
  auto tmp_allocation_ptr = memory::Alloc(place, numels * dtype_size);
  auto& deleter = tmp_allocation_ptr.get_deleter();
  auto* allocation_ptr = tmp_allocation_ptr.release();
  auto shared_allocation =
261
      std::shared_ptr<phi::Allocation>(allocation_ptr, deleter);
262 263 264 265

  dense_tensor->ResetHolder(shared_allocation);
}

266 267
}  // namespace framework
}  // namespace paddle