// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include <glog/logging.h>
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/kernels/declarations.h"

namespace infrt {

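// Maps an infrt TargetType to the target prefix used in phi dialect kernel
// names, e.g. "phi_cpu." or "phi_gpu.".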
std::string getPhiTargetPrefix(TargetType target) {
  switch (target) {
    case TargetType::CPU:
      return "phi_cpu.";
    case TargetType::GPU:
      return "phi_gpu.";
    default:
      LOG(FATAL) << "Unsupported target type!";
      return std::string();
  }
}
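// Maps a PrecisionType to the dtype suffix appended to a phi dialect kernel
// name, e.g. ".float32".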
std::string getPhiPrecisionSuffix(PrecisionType precision) {
  switch (precision) {
    case PrecisionType::FLOAT32:
      return ".float32";
    case PrecisionType::FLOAT16:
      return ".float16";
    case PrecisionType::FLOAT64:
      return ".float64";
    case PrecisionType::UINT8:
      return ".uint8";
    case PrecisionType::INT8:
      return ".int8";
    case PrecisionType::INT16:
      return ".int16";
    case PrecisionType::INT32:
      return ".int32";
    case PrecisionType::INT64:
      return ".int64";
    case PrecisionType::COMPLEX64:
      return ".complex64";
    case PrecisionType::COMPLEX128:
      return ".complex128";
    case PrecisionType::BOOL:
      return ".bool";
    default:
      LOG(FATAL) << "Unsupported precision type!";
      return std::string();
  }
}
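// Maps a LayoutType to the layout suffix appended to a phi dialect kernel
// name, e.g. ".nchw".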
std::string getPhiLayoutSuffix(LayoutType layout) {
  switch (layout) {
    case LayoutType::NCHW:
      return ".nchw";
    case LayoutType::NHWC:
      return ".nhwc";
    case LayoutType::ANY:
      return ".any";
    default:
      LOG(FATAL) << "Unsupported layout type!";
      return std::string();
  }
}

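// Returns one kernel descriptor for every place in `valid_places` that the
// phi kernel registry can serve for the kernel named `name`. If no kernel is
// registered for a place's exact layout, the place falls back to
// LayoutType::ANY (matching kernels registered with ALL_LAYOUT).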
std::vector<PhiKernelDesc> GetCandidateKernels(
    std::string name, const std::vector<Place>& valid_places) {
  std::vector<PhiKernelDesc> candidate_kernels;
  PhiKernelDesc phi_kernel_desc;
  phi::KernelKeyMap kernel_key_map =
      phi::KernelFactory::Instance().SelectKernelMap(name);
  for (Place place : valid_places) {
    phi::KernelKey kernel_key = ConvertPlaceToPhi(place);
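    // No exact match for (backend, layout, dtype): retry with ALL_LAYOUT and
    // mark this place's layout as ANY; skip the place if that also fails.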
    if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
      kernel_key = phi::KernelKey(kernel_key.backend(),
                                  phi::DataLayout::ALL_LAYOUT,
                                  kernel_key.dtype());
      if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue;
      place.layout = LayoutType::ANY;
    }
    phi_kernel_desc.kernel_type = place;
    phi_kernel_desc.input_types.clear();
    phi_kernel_desc.output_types.clear();
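    // Translate the kernel's input/output tensor argument definitions into
    // the descriptor's type lists.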
    phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
    const paddle::small_vector<phi::TensorArgDef, phi::kInputSmallVectorSize>&
        input_arg = args_def.input_defs();
    const paddle::small_vector<phi::TensorArgDef, phi::kOutputSmallVectorSize>&
        output_arg = args_def.output_defs();
    for (const auto& tensor_arg : input_arg) {
      phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
    }
    for (const auto& tensor_arg : output_arg) {
      phi_kernel_desc.output_types.emplace_back(
          ConvertPlaceFromPhi(tensor_arg));
    }
    candidate_kernels.emplace_back(phi_kernel_desc);
  }
  return candidate_kernels;
}

}  // namespace infrt