diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 4e81d1eaa9dfc9792140411f90aab58087146bef..7d563a3c059874de7c4dc8c4d13ac7dc9139bf47 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -451,8 +451,8 @@ class OperatorWithKernel : public OperatorBase { size_t operator()(const OpKernelKey& key) const { int place = key.place_.which(); int data_type = static_cast(key.data_type_); - // NOTE: Number of places limit to 16. - int pre_hash = data_type << 4 | (place & 0x0F); + int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT | + (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1)); return hash_(pre_hash); } }; diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 1117476bb37f1b0f3876c55e610803d5ee2558ce..0efc6932349a5b3ad295d195a16737a642e18943 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include + #include "paddle/platform/variant.h" namespace paddle { @@ -46,8 +47,18 @@ struct IsGPUPlace : public boost::static_visitor { bool operator()(const GPUPlace &gpu) const { return true; } }; +// Define the max number of Place in bit length. i.e., the max number of places +// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) +#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 + typedef boost::variant Place; +// static check number of place types is less equal than +// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) +BOOST_MPL_ASSERT((boost::mpl::less_equal< + Place::types::size, + boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>)); + void set_place(const Place &); const Place &get_place(); diff --git a/paddle/platform/variant.h b/paddle/platform/variant.h index c2257af1b5dd1a1e284979bf17e1a947072baa85..16ee00efe7a9b0406f8459e19a55e1e1b9ca7419 100644 --- a/paddle/platform/variant.h +++ b/paddle/platform/variant.h @@ -29,4 +29,6 @@ #endif #endif +#include +#include #include