/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <ostream>
#include <sstream>
#include <string>

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {

X
Xin Pan 已提交
27 28 29
class OpKernelType {
 public:
  constexpr static int kDefaultCustomizedTypeValue = 0;
Q
QI JUN 已提交
30

X
Xin Pan 已提交
31 32 33 34 35 36
  // In total should be smaller than 64.
  constexpr static int kPlaceBits = 4;
  constexpr static int kPrimaryDTypeBits = 8;
  constexpr static int kLayoutBits = 4;
  constexpr static int kLibBits = 4;
  constexpr static int kCustomizeBits = 4;
Q
QI JUN 已提交
37

38
  OpKernelType(proto::VarType::Type data_type, platform::Place place,
Q
QI JUN 已提交
39
               DataLayout data_layout = DataLayout::kAnyLayout,
X
Xin Pan 已提交
40 41
               LibraryType library_type = LibraryType::kPlain,
               int customized_type_value = kDefaultCustomizedTypeValue)
Q
QI JUN 已提交
42 43 44
      : data_type_(data_type),
        data_layout_(data_layout),
        place_(place),
X
Xin Pan 已提交
45 46
        library_type_(library_type),
        customized_type_value_(customized_type_value) {}
Q
QI JUN 已提交
47

48
  OpKernelType(proto::VarType::Type data_type,
Q
QI JUN 已提交
49 50
               const platform::DeviceContext& dev_ctx,
               DataLayout data_layout = DataLayout::kAnyLayout,
X
Xin Pan 已提交
51 52
               LibraryType library_type = LibraryType::kPlain,
               int customized_type_value = kDefaultCustomizedTypeValue)
Q
QI JUN 已提交
53 54 55
      : data_type_(data_type),
        data_layout_(data_layout),
        place_(dev_ctx.GetPlace()),
X
Xin Pan 已提交
56 57 58 59 60 61 62 63
        library_type_(library_type),
        customized_type_value_(customized_type_value) {}

  virtual ~OpKernelType() {}

  struct Hash {
    size_t operator()(const OpKernelType& key) const;
  };
Q
QI JUN 已提交
64

65 66
  size_t hash_key() const { return Hash()(*this); }

X
Xin Pan 已提交
67
  bool operator==(const OpKernelType& o) const;
Q
QI JUN 已提交
68 69

  bool operator!=(const OpKernelType& o) const { return !(*this == o); }
X
Xin Pan 已提交
70 71 72 73 74 75

  proto::VarType::Type data_type_;
  DataLayout data_layout_;
  platform::Place place_;
  LibraryType library_type_;
  int customized_type_value_;
Q
QI JUN 已提交
76 77
};

Q
qiaolongfei 已提交
78 79 80 81 82 83 84 85
inline std::ostream& operator<<(std::ostream& os,
                                const OpKernelType& kernel_key) {
  os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
     << kernel_key.data_layout_ << "]:place[" << kernel_key.place_
     << "]:library_type[" << kernel_key.library_type_ << "]";
  return os;
}

Q
QI JUN 已提交
86 87 88 89 90 91
inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
  std::ostringstream stream;
  stream << kernel_key;
  return stream.str();
}

92
inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
M
mozga-intel 已提交
93 94 95 96 97 98 99 100
  bool ret =
      (l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r);
#ifdef PADDLE_WITH_MKLDNN
  // Layout transform needed for either non-MKLDNN to MKLDNN or vice versa
  ret |= (l != DataLayout::kMKLDNN && r == DataLayout::kMKLDNN);
  ret |= (l == DataLayout::kMKLDNN && r != DataLayout::kMKLDNN);
#endif
  return ret;
101 102
}

Y
yuyang18 已提交
103
inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) {
104
  return (!platform::places_are_same_class(l.place_, r.place_)) ||
105 106
         (l.data_type_ != r.data_type_) ||
         NeedTransformLayout(l.data_layout_, r.data_layout_);
107 108
}

Q
QI JUN 已提交
109 110
}  // namespace framework
}  // namespace paddle