未验证 提交 7e659a0a 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #6932 from dzhwinter/fix/kernelkey

"remove hash combine"
...@@ -22,33 +22,23 @@ limitations under the License. */ ...@@ -22,33 +22,23 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/*
Refer to https://stackoverflow.com/questions/35985960/
c-why-is-boosthash-combine-the-best-way-to-combine-hash-values
*/
template <class T>
inline void HashCombine(const T& v, std::size_t* seed) {
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
struct OpKernelType { struct OpKernelType {
struct Hash { struct Hash {
size_t operator()(const OpKernelType& key) const { size_t operator()(const OpKernelType& key) const {
int place = key.place_.which(); int place = key.place_.which() + (1 << LEFT_SHIFT);
int data_type = static_cast<int>(key.data_type_); int data_type =
int data_layout = static_cast<int>(key.data_layout_); static_cast<int>(key.data_type_) + (1 << (LEFT_SHIFT + 1));
int library_type = static_cast<int>(key.library_type_); int data_layout =
static_cast<int>(key.data_layout_) + (1 << (LEFT_SHIFT + 2));
size_t seed = 0; int library_type =
HashCombine(place, &seed); static_cast<int>(key.library_type_) + (1 << (LEFT_SHIFT + 3));
HashCombine(data_type, &seed); std::hash<int> hasher;
HashCombine(data_layout, &seed); return hasher(place + data_type + data_layout + library_type);
HashCombine(library_type, &seed);
return seed;
} }
}; };
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_; proto::DataType data_type_;
DataLayout data_layout_; DataLayout data_layout_;
platform::Place place_; platform::Place place_;
......
...@@ -137,11 +137,11 @@ class DeviceContextPool { ...@@ -137,11 +137,11 @@ class DeviceContextPool {
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
struct Hash { struct Hash {
std::hash<int> hash_; std::hash<int> hash_;
size_t operator()(const platform::Place& place) const { size_t operator()(const platform::Place& place) const {
int pre_hash = place.which() int pre_hash = place.which() + (1 << LEFT_SHIFT);
<< (sizeof(int) * 8 - NUM_PLACE_TYPE_LIMIT_IN_BIT);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId(); pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册