diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h index 45bbbe580d52652a44b913e6d1b7313c6b4e9361..a1dea0d9d864881ef1f60b117dfaa02da3aa4275 100644 --- a/paddle/framework/op_kernel_type.h +++ b/paddle/framework/op_kernel_type.h @@ -22,33 +22,23 @@ limitations under the License. */ namespace paddle { namespace framework { -/* -Refer to https://stackoverflow.com/questions/35985960/ -c-why-is-boosthash-combine-the-best-way-to-combine-hash-values -*/ -template -inline void HashCombine(const T& v, std::size_t* seed) { - std::hash hasher; - *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); -} - struct OpKernelType { struct Hash { size_t operator()(const OpKernelType& key) const { - int place = key.place_.which(); - int data_type = static_cast(key.data_type_); - int data_layout = static_cast(key.data_layout_); - int library_type = static_cast(key.library_type_); - - size_t seed = 0; - HashCombine(place, &seed); - HashCombine(data_type, &seed); - HashCombine(data_layout, &seed); - HashCombine(library_type, &seed); - return seed; + int place = key.place_.which() + (1 << LEFT_SHIFT); + int data_type = + static_cast(key.data_type_) + (1 << (LEFT_SHIFT + 1)); + int data_layout = + static_cast(key.data_layout_) + (1 << (LEFT_SHIFT + 2)); + int library_type = + static_cast(key.library_type_) + (1 << (LEFT_SHIFT + 3)); + std::hash hasher; + return hasher(place + data_type + data_layout + library_type); } }; + // place, data_type, library_type kinds less than 2^8 + constexpr static int LEFT_SHIFT = 8; proto::DataType data_type_; DataLayout data_layout_; platform::Place place_; diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 1d46ce5c7031c2a27dde42c838ff444ce4ac6f54..9b958f7c920a32c9208f3dfd3ff54ac9620da9e7 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -137,11 +137,11 @@ class DeviceContextPool { private: static DeviceContextPool* pool; + constexpr static int LEFT_SHIFT = 8; struct Hash { std::hash hash_; size_t operator()(const platform::Place& place) const { - int pre_hash = place.which() - << (sizeof(int) * 8 - NUM_PLACE_TYPE_LIMIT_IN_BIT); + int pre_hash = place.which() + (1 << LEFT_SHIFT); if (platform::is_gpu_place(place)) { pre_hash += boost::get(place).GetDeviceId(); }