From e0c33176460ffe322c8ff6da024628741423d020 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 14 Dec 2017 23:14:18 +0800 Subject: [PATCH] add MKLDNNPlace --- paddle/platform/place.cc | 6 ++++++ paddle/platform/place.h | 19 ++++++++++++++++++- paddle/platform/place_test.cc | 6 ++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc index 856e54df89c..16f17742858 100644 --- a/paddle/platform/place.cc +++ b/paddle/platform/place.cc @@ -23,6 +23,7 @@ class PlacePrinter : public boost::static_visitor<> { public: explicit PlacePrinter(std::ostream &os) : os_(os) {} void operator()(const CPUPlace &) { os_ << "CPUPlace"; } + void operator()(const MKLDNNPlace &) { os_ << "MKLDNNPlace"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; } private: @@ -38,6 +39,7 @@ const Place &get_place() { return the_default_place; } const GPUPlace default_gpu() { return GPUPlace(0); } const CPUPlace default_cpu() { return CPUPlace(); } +const MKLDNNPlace default_mkldnn() { return MKLDNNPlace(); } bool is_gpu_place(const Place &p) { return boost::apply_visitor(IsGPUPlace(), p); @@ -46,6 +48,10 @@ bool is_cpu_place(const Place &p) { return !boost::apply_visitor(IsGPUPlace(), p); } +bool is_mkldnn_place(const Place &p) { + return boost::apply_visitor(IsMKLDNNPlace(), p); +} + bool places_are_same_class(const Place &p1, const Place &p2) { return p1.which() == p2.which(); } diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 5370360a7de..e745b2e8391 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -31,6 +31,14 @@ struct CPUPlace { inline bool operator!=(const CPUPlace &) const { return false; } }; +struct MKLDNNPlace : public CPUPlace { + MKLDNNPlace() {} + + // needed for variant equality comparison + inline bool operator==(const MKLDNNPlace &) const { return true; } + inline bool operator!=(const MKLDNNPlace &) const { return false; } +}; + struct GPUPlace { GPUPlace() : GPUPlace(0) {} explicit GPUPlace(int d) : device(d) {} @@ -45,14 +53,21 @@ struct GPUPlace { struct IsGPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } + bool operator()(const MKLDNNPlace &) const { return false; } bool operator()(const GPUPlace &gpu) const { return true; } }; +struct IsMKLDNNPlace : public boost::static_visitor { + bool operator()(const MKLDNNPlace &) const { return true; } + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const GPUPlace &) const { return false; } +}; + // Define the max number of Place in bit length. i.e., the max number of places // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) @@ -65,9 +80,11 @@ const Place &get_place(); const GPUPlace default_gpu(); const CPUPlace default_cpu(); +const MKLDNNPlace default_mkldnn(); bool is_gpu_place(const Place &); bool is_cpu_place(const Place &); +bool is_mkldnn_place(const Place &); bool places_are_same_class(const Place &, const Place &); std::ostream &operator<<(std::ostream &, const Place &); diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc index 33e2e5a439c..184af12c230 100644 --- a/paddle/platform/place_test.cc +++ b/paddle/platform/place_test.cc @@ -21,9 +21,15 @@ TEST(Place, Default) { EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); + EXPECT_TRUE( + paddle::platform::is_mkldnn_place(paddle::platform::default_mkldnn())); paddle::platform::set_place(paddle::platform::CPUPlace()); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); + + paddle::platform::set_place(paddle::platform::MKLDNNPlace()); + EXPECT_FALSE(paddle::platform::is_cpu_place(paddle::platform::get_place())); + EXPECT_TRUE(paddle::platform::is_mkldnn_place(paddle::platform::get_place())); } TEST(Place, Print) { -- GitLab