diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 2b35e4532a9c9f72f473020d472244234af24248..a05810d7781f5286e70b53005ef0b193c945c54c 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -277,6 +277,14 @@ void set_constant_with_place( TensorSetConstantCPU(tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstantCPU(tensor, value)); +} + struct TensorSetConstantWithPlace : public boost::static_visitor { TensorSetConstantWithPlace(const platform::DeviceContext& context, framework::Tensor* tensor, float value) diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc index 856e54df89c1c18ade040957188a2fbda0901473..25fe8d21b49b07a6afe2938245906dc1bdd90398 100644 --- a/paddle/platform/place.cc +++ b/paddle/platform/place.cc @@ -23,6 +23,7 @@ class PlacePrinter : public boost::static_visitor<> { public: explicit PlacePrinter(std::ostream &os) : os_(os) {} void operator()(const CPUPlace &) { os_ << "CPUPlace"; } + void operator()(const MKLDNNPlace &) { os_ << "MKLDNNPlace"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; } private: @@ -38,12 +39,17 @@ const Place &get_place() { return the_default_place; } const GPUPlace default_gpu() { return GPUPlace(0); } const CPUPlace default_cpu() { return CPUPlace(); } +const MKLDNNPlace default_mkldnn() { return MKLDNNPlace(); } bool is_gpu_place(const Place &p) { return boost::apply_visitor(IsGPUPlace(), p); } bool is_cpu_place(const Place &p) { - return !boost::apply_visitor(IsGPUPlace(), p); + return !is_gpu_place(p) && !is_mkldnn_place(p); +} + +bool is_mkldnn_place(const Place &p) { + return boost::apply_visitor(IsMKLDNNPlace(), p); } bool places_are_same_class(const Place &p1, const Place &p2) { diff --git a/paddle/platform/place.h b/paddle/platform/place.h index f0dcec8f523fb22c2dd046113b6a8f8a0b6d916d..4526945792b2ea96cc4e9df11d8f35897cba7526 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -31,6 +31,14 @@ struct CPUPlace { inline bool operator!=(const CPUPlace &) const { return false; } }; +struct MKLDNNPlace { + MKLDNNPlace() {} + + // needed for variant equality comparison + inline bool operator==(const MKLDNNPlace &) const { return true; } + inline bool operator!=(const MKLDNNPlace &) const { return false; } +}; + struct GPUPlace { GPUPlace() : GPUPlace(0) {} explicit GPUPlace(int d) : device(d) {} @@ -50,14 +58,21 @@ struct CudnnPlace : public GPUPlace { struct IsGPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } + bool operator()(const MKLDNNPlace &) const { return false; } bool operator()(const GPUPlace &gpu) const { return true; } }; +struct IsMKLDNNPlace : public boost::static_visitor { + bool operator()(const MKLDNNPlace &) const { return true; } + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const GPUPlace &) const { return false; } +}; + // Define the max number of Place in bit length. i.e., the max number of places // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) @@ -70,9 +85,11 @@ const Place &get_place(); const GPUPlace default_gpu(); const CPUPlace default_cpu(); +const MKLDNNPlace default_mkldnn(); bool is_gpu_place(const Place &); bool is_cpu_place(const Place &); +bool is_mkldnn_place(const Place &); bool places_are_same_class(const Place &, const Place &); std::ostream &operator<<(std::ostream &, const Place &); diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc index 33e2e5a439ce6801c02daba4bcbd462a74d7a614..184af12c230f1ccd7826e507f16f4e91ca380a45 100644 --- a/paddle/platform/place_test.cc +++ b/paddle/platform/place_test.cc @@ -21,9 +21,15 @@ TEST(Place, Default) { EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); + EXPECT_TRUE( + paddle::platform::is_mkldnn_place(paddle::platform::default_mkldnn())); paddle::platform::set_place(paddle::platform::CPUPlace()); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); + + paddle::platform::set_place(paddle::platform::MKLDNNPlace()); + EXPECT_FALSE(paddle::platform::is_cpu_place(paddle::platform::get_place())); + EXPECT_TRUE(paddle::platform::is_mkldnn_place(paddle::platform::get_place())); } TEST(Place, Print) {