diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 2fca816f353635d3bff184323755961ee82fbb68..a67625fa88fd2fbe4db43241ee824519ceac7017 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -252,5 +252,20 @@ std::ostream& operator<<(std::ostream& os, return os; } +bool OpSupportGPU(const std::string& op_type) { + auto& all_kernels = OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op_type); + if (it == all_kernels.end()) { + // All control operator must support GPU + return true; + } + for (auto& kern_pair : it->second) { + if (platform::is_gpu_place(kern_pair.first.place_)) { + return true; + } + } + return false; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 12cd307297d010201a47e048089ed7be0db52647..9d7fe1f5ba293227e67cf6bfcd566a1247c567ed 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -649,5 +649,7 @@ class OperatorWithKernel : public OperatorBase { std::ostream& operator<<(std::ostream& os, const OperatorWithKernel::OpKernelKey& kernel_key); +extern bool OpSupportGPU(const std::string& op_type); + } // namespace framework } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e5ddc14587623905dbf52b4c1690236ffeb069a1..26b793a4bbf5df7a2635838a6c6a8264ca8ebb67 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -466,6 +466,8 @@ All parameter, weight, gradient are variables in Paddle. BindVarDsec(m); BindOpDesc(m); + m.def("op_support_gpu", OpSupportGPU); + return m.ptr(); } } // namespace pybind diff --git a/python/paddle/v2/framework/tests/test_op_support_gpu.py b/python/paddle/v2/framework/tests/test_op_support_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..dd36c666c440a5c378dfceac4502cd8277417412 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_op_support_gpu.py @@ -0,0 +1,11 @@ +import unittest +import paddle.v2.framework.core as core + + +class TestOpSupportGPU(unittest.TestCase): + def test_case(self): + self.assertEqual(core.is_compile_gpu(), core.op_support_gpu("sum")) + + +if __name__ == '__main__': + unittest.main()