kernel.cc 3.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/kernel.h"
#include <cstdlib>
#include "lite/utils/string.h"

namespace paddle {
namespace lite {

std::string KernelBase::summary() const {
  STL::stringstream ss;
  ss << op_type() << ":" << TargetToStr(target()) << "/"
     << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout()) << "("
     << alias() << ")";
  return ss.str();
}

const Type *KernelBase::GetInputDeclType(const std::string &arg_name) const {
  CHECK(!op_type_.empty()) << "op_type should be set first";
  const auto *type = ParamTypeRegistry::Global().RetrieveInArgument(
      place(), GenParamTypeKey(), arg_name);
  CHECK(type) << "no type registered for kernel [" << op_type_
              << "] input argument [" << arg_name << "]"
              << " with key " << GenParamTypeKey();
  return type->type;
}

const Type *KernelBase::GetOutputDeclType(const std::string &arg_name) const {
  CHECK(!op_type_.empty()) << "op_type should be set first";
  const auto *type = ParamTypeRegistry::Global().RetrieveOutArgument(
      place(), GenParamTypeKey(), arg_name);
  CHECK(type) << "no type registered for kernel [" << GenParamTypeKey()
              << "] output argument [" << arg_name << "]";
  return type->type;
}

std::string KernelBase::GenParamTypeKey() const {
  STL::stringstream ss;
  ss << op_type() << "/" << alias_;
  return ss.str();
}

void KernelBase::ParseKernelType(const std::string &kernel_type,
                                 std::string *op_type,
                                 std::string *alias,
                                 Place *place) {
  auto parts = Split(kernel_type, "/");
  CHECK_EQ(parts.size(), 5);
  *op_type = parts[0];
  *alias = parts[1];

  std::string target, precision, layout;

  target = parts[2];
  precision = parts[3];
  layout = parts[4];

  place->target = static_cast<TargetType>(std::atoi(target.c_str()));
  place->precision = static_cast<PrecisionType>(std::atoi(precision.c_str()));
  place->layout = static_cast<DataLayoutType>(std::atoi(layout.c_str()));
}

std::string KernelBase::SerializeKernelType(const std::string &op_type,
                                            const std::string &alias,
                                            const Place &place) {
  STL::stringstream ss;
  ss << op_type << "/";
  ss << alias << "/";
  // We serialize the place value not the string representation here for
  // easier deserialization.
  ss << static_cast<int>(place.target) << "/";
  ss << static_cast<int>(place.precision) << "/";
  ss << static_cast<int>(place.layout);
  return ss.str();
}

bool ParamTypeRegistry::KeyCmp::operator()(
    const ParamTypeRegistry::key_t &a,
    const ParamTypeRegistry::key_t &b) const {
  return a.hash() < b.hash();
}

STL::ostream &operator<<(STL::ostream &os,
                         const ParamTypeRegistry::KernelIdTy &other) {
  std::string io_s = other.io == ParamTypeRegistry::IO::kInput ? "in" : "out";
  os << other.kernel_type << ":" << other.arg_name << ":" << io_s << ":"
     << other.place.DebugString();
  return os;
}

}  // namespace lite
}  // namespace paddle