提交 86437a8d 编写于 作者: Y Yu Yang 提交者: GitHub

Global function, op_support_gpu (#4980)

上级 d2f3c8bb
......@@ -252,5 +252,20 @@ std::ostream& operator<<(std::ostream& os,
return os;
}
bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU
return true;
}
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
return true;
}
}
return false;
}
} // namespace framework
} // namespace paddle
......@@ -649,5 +649,7 @@ class OperatorWithKernel : public OperatorBase {
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key);
extern bool OpSupportGPU(const std::string& op_type);
} // namespace framework
} // namespace paddle
......@@ -466,6 +466,8 @@ All parameter, weight, gradient are variables in Paddle.
BindVarDsec(m);
BindOpDesc(m);
m.def("op_support_gpu", OpSupportGPU);
return m.ptr();
}
} // namespace pybind
......
import unittest
import paddle.v2.framework.core as core
class TestOpSupportGPU(unittest.TestCase):
def test_case(self):
self.assertEqual(core.is_compile_gpu(), core.op_support_gpu("sum"))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册