未验证 提交 b858e17b 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #6592 from tensor-tang/fluid

add MKLDNNPlace
...@@ -277,6 +277,14 @@ void set_constant_with_place<platform::CPUPlace>( ...@@ -277,6 +277,14 @@ void set_constant_with_place<platform::CPUPlace>(
TensorSetConstantCPU(tensor, value)); TensorSetConstantCPU(tensor, value));
} }
template <>
void set_constant_with_place<platform::MKLDNNPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstantCPU(tensor, value));
}
struct TensorSetConstantWithPlace : public boost::static_visitor<void> { struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
TensorSetConstantWithPlace(const platform::DeviceContext& context, TensorSetConstantWithPlace(const platform::DeviceContext& context,
framework::Tensor* tensor, float value) framework::Tensor* tensor, float value)
......
...@@ -23,6 +23,7 @@ class PlacePrinter : public boost::static_visitor<> { ...@@ -23,6 +23,7 @@ class PlacePrinter : public boost::static_visitor<> {
public: public:
explicit PlacePrinter(std::ostream &os) : os_(os) {} explicit PlacePrinter(std::ostream &os) : os_(os) {}
void operator()(const CPUPlace &) { os_ << "CPUPlace"; } void operator()(const CPUPlace &) { os_ << "CPUPlace"; }
void operator()(const MKLDNNPlace &) { os_ << "MKLDNNPlace"; }
void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; }
private: private:
...@@ -38,12 +39,17 @@ const Place &get_place() { return the_default_place; } ...@@ -38,12 +39,17 @@ const Place &get_place() { return the_default_place; }
const GPUPlace default_gpu() { return GPUPlace(0); } const GPUPlace default_gpu() { return GPUPlace(0); }
const CPUPlace default_cpu() { return CPUPlace(); } const CPUPlace default_cpu() { return CPUPlace(); }
const MKLDNNPlace default_mkldnn() { return MKLDNNPlace(); }
bool is_gpu_place(const Place &p) { bool is_gpu_place(const Place &p) {
return boost::apply_visitor(IsGPUPlace(), p); return boost::apply_visitor(IsGPUPlace(), p);
} }
bool is_cpu_place(const Place &p) { bool is_cpu_place(const Place &p) {
return !boost::apply_visitor(IsGPUPlace(), p); return !is_gpu_place(p) && !is_mkldnn_place(p);
}
bool is_mkldnn_place(const Place &p) {
return boost::apply_visitor(IsMKLDNNPlace(), p);
} }
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
......
...@@ -31,6 +31,14 @@ struct CPUPlace { ...@@ -31,6 +31,14 @@ struct CPUPlace {
inline bool operator!=(const CPUPlace &) const { return false; } inline bool operator!=(const CPUPlace &) const { return false; }
}; };
struct MKLDNNPlace {
MKLDNNPlace() {}
// needed for variant equality comparison
inline bool operator==(const MKLDNNPlace &) const { return true; }
inline bool operator!=(const MKLDNNPlace &) const { return false; }
};
struct GPUPlace { struct GPUPlace {
GPUPlace() : GPUPlace(0) {} GPUPlace() : GPUPlace(0) {}
explicit GPUPlace(int d) : device(d) {} explicit GPUPlace(int d) : device(d) {}
...@@ -50,14 +58,21 @@ struct CudnnPlace : public GPUPlace { ...@@ -50,14 +58,21 @@ struct CudnnPlace : public GPUPlace {
struct IsGPUPlace : public boost::static_visitor<bool> { struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const MKLDNNPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
}; };
struct IsMKLDNNPlace : public boost::static_visitor<bool> {
bool operator()(const MKLDNNPlace &) const { return true; }
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &) const { return false; }
};
// Define the max number of Place in bit length. i.e., the max number of places // Define the max number of Place in bit length. i.e., the max number of places
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place; typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
// static check number of place types is less equal than // static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
...@@ -70,9 +85,11 @@ const Place &get_place(); ...@@ -70,9 +85,11 @@ const Place &get_place();
const GPUPlace default_gpu(); const GPUPlace default_gpu();
const CPUPlace default_cpu(); const CPUPlace default_cpu();
const MKLDNNPlace default_mkldnn();
bool is_gpu_place(const Place &); bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &); bool is_cpu_place(const Place &);
bool is_mkldnn_place(const Place &);
bool places_are_same_class(const Place &, const Place &); bool places_are_same_class(const Place &, const Place &);
std::ostream &operator<<(std::ostream &, const Place &); std::ostream &operator<<(std::ostream &, const Place &);
......
...@@ -21,9 +21,15 @@ TEST(Place, Default) { ...@@ -21,9 +21,15 @@ TEST(Place, Default) {
EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::get_place()));
EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu()));
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu()));
EXPECT_TRUE(
paddle::platform::is_mkldnn_place(paddle::platform::default_mkldnn()));
paddle::platform::set_place(paddle::platform::CPUPlace()); paddle::platform::set_place(paddle::platform::CPUPlace());
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place()));
paddle::platform::set_place(paddle::platform::MKLDNNPlace());
EXPECT_FALSE(paddle::platform::is_cpu_place(paddle::platform::get_place()));
EXPECT_TRUE(paddle::platform::is_mkldnn_place(paddle::platform::get_place()));
} }
TEST(Place, Print) { TEST(Place, Print) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册