diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index daa70e943ded1909785458833349e4f8091847d1..4e7e157e3e2b4c021b5f09d3483bf41782a5ae8d 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -76,7 +76,24 @@ struct IsCUDAPinnedPlace : public boost::static_visitor { bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; } }; -typedef boost::variant Place; +class Place : public boost::variant { + private: + using PlaceBase = boost::variant; + + public: + Place() = default; + Place(const CPUPlace &cpu_place) : PlaceBase(cpu_place) {} // NOLINT + Place(const CUDAPlace &cuda_place) : PlaceBase(cuda_place) {} // NOLINT + Place(const CUDAPinnedPlace &cuda_pinned_place) // NOLINT + : PlaceBase(cuda_pinned_place) {} + + bool operator<(const Place &place) const { + return PlaceBase::operator<(static_cast(place)); + } + bool operator==(const Place &place) const { + return PlaceBase::operator==(static_cast(place)); + } +}; using PlaceList = std::vector;