diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h index 68e9cabb667a5b2421fad8333d8e1be7bfa57002..49b273656bf57f183209e3d0996358da28ec0e7a 100644 --- a/paddle/framework/library_type.h +++ b/paddle/framework/library_type.h @@ -20,7 +20,7 @@ namespace framework { // For more details about the design of LibraryType, Please refer to // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library -enum LibraryType { kPlain = 0; kMKLDNN = 1; kCUDNN = 2; } +enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 }; } // namespace } // framework diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h new file mode 100644 index 0000000000000000000000000000000000000000..45bbbe580d52652a44b913e6d1b7313c6b4e9361 --- /dev/null +++ b/paddle/framework/op_kernel_type.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/data_layout.h" +#include "paddle/framework/data_type.h" +#include "paddle/framework/library_type.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace framework { + +/* +Refer to https://stackoverflow.com/questions/35985960/ +c-why-is-boosthash-combine-the-best-way-to-combine-hash-values +*/ +template +inline void HashCombine(const T& v, std::size_t* seed) { + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); +} + +struct OpKernelType { + struct Hash { + size_t operator()(const OpKernelType& key) const { + int place = key.place_.which(); + int data_type = static_cast(key.data_type_); + int data_layout = static_cast(key.data_layout_); + int library_type = static_cast(key.library_type_); + + size_t seed = 0; + HashCombine(place, &seed); + HashCombine(data_type, &seed); + HashCombine(data_layout, &seed); + HashCombine(library_type, &seed); + return seed; + } + }; + + proto::DataType data_type_; + DataLayout data_layout_; + platform::Place place_; + LibraryType library_type_; + + OpKernelType(proto::DataType data_type, platform::Place place, + DataLayout data_layout = DataLayout::kAnyLayout, + LibraryType library_type = LibraryType::kPlain) + : data_type_(data_type), + data_layout_(data_layout), + place_(place), + library_type_(library_type) {} + + OpKernelType(proto::DataType data_type, + const platform::DeviceContext& dev_ctx, + DataLayout data_layout = DataLayout::kAnyLayout, + LibraryType library_type = LibraryType::kPlain) + : data_type_(data_type), + data_layout_(data_layout), + place_(dev_ctx.GetPlace()), + library_type_(library_type) {} + + bool operator==(const OpKernelType& o) const { + return platform::places_are_same_class(place_, o.place_) && + data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && + library_type_ == o.library_type_; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 5d38ef5bebc33039e0391069ae87a974fff537af..06184f6ba968c438f6baa571d7a5c12a69109c84 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -243,8 +243,9 @@ std::vector ExecutionContext::MultiOutput( } std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) { - os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_ - << "]"; + os << "data_type[" << kernel_key.data_type_ << "]:data_layout[" + << kernel_key.data_layout_ << "]:place[" << kernel_key.place_ + << "]:library_type[" << kernel_key.library_type_ << "]"; return os; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ef750aff1bae2540690ccc82b1105449344bf4ab..aba34c5bcb81c85db21e9d82894fc0b937c3c060 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -23,15 +23,14 @@ limitations under the License. */ #include "glog/logging.h" // For VLOG #include "paddle/framework/attribute.h" #include "paddle/framework/block_desc.h" -#include "paddle/framework/data_type.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_info.h" +#include "paddle/framework/op_kernel_type.h" #include "paddle/framework/scope.h" #include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" -#include "paddle/platform/place.h" #include "paddle/platform/variant.h" #include "paddle/utils/Error.h" @@ -343,34 +342,6 @@ class OpKernel : public OpKernelBase { using ELEMENT_TYPE = T; }; -struct OpKernelType { - struct Hash { - std::hash hash_; - size_t operator()(const OpKernelType& key) const { - int place = key.place_.which(); - int data_type = static_cast(key.data_type_); - int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT | - (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1)); - return hash_(pre_hash); - } - }; - - platform::Place place_; - proto::DataType data_type_; - - OpKernelType(proto::DataType data_type, platform::Place place) - : place_(place), data_type_(data_type) {} - - OpKernelType(proto::DataType data_type, - const platform::DeviceContext& dev_ctx) - : place_(dev_ctx.GetPlace()), data_type_(data_type) {} - - bool operator==(const OpKernelType& o) const { - return platform::places_are_same_class(place_, o.place_) && - data_type_ == o.data_type_; - } -}; - class OperatorWithKernel : public OperatorBase { public: using OpKernelMap = diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 6bff2d4d9cd7eefaa7212af2a1287e9aaff7d684..daeafbbcd780aaeab20c8fcbbeed60a587e0049b 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -70,18 +70,8 @@ struct IsMKLDNNPlace : public boost::static_visitor { bool operator()(const CUDNNPlace &) const { return false; } }; -// Define the max number of Place in bit length. i.e., the max number of places -// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) -#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 - typedef boost::variant Place; -// static check number of place types is less equal than -// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) -BOOST_MPL_ASSERT((boost::mpl::less_equal< - Place::types::size, - boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>)); - void set_place(const Place &); const Place &get_place();