未验证 提交 8d4f2613 编写于 作者: C Chen Weihang 提交者: GitHub

add op count by lib method (#45680)

上级 7a92e74b
...@@ -1049,13 +1049,44 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1049,13 +1049,44 @@ All parameter, weight, gradient are variables in Paddle.
} }
return ret_values; return ret_values;
}); });
m.def("get_all_op_names", []() { m.def(
std::vector<std::string> op_names; "get_all_op_names",
for (auto &iter : OpInfoMap::Instance().map()) { [](const std::string &lib) {
op_names.emplace_back(iter.first); std::vector<std::string> op_names;
} for (auto &iter : OpInfoMap::Instance().map()) {
return op_names; op_names.emplace_back(iter.first);
}); }
if (lib == "phi") {
std::vector<std::string> ops_with_phi_kernel;
for (const auto &op_name : op_names) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_name)) {
ops_with_phi_kernel.emplace_back(op_name);
}
}
return ops_with_phi_kernel;
} else if (lib == "fluid") {
std::vector<std::string> ops_with_fluid_kernel;
auto all_fluid_op_kernels =
paddle::framework::OperatorWithKernel::AllOpKernels();
for (const auto &op_name : op_names) {
if (all_fluid_op_kernels.find(op_name) !=
all_fluid_op_kernels.end()) {
ops_with_fluid_kernel.emplace_back(op_name);
}
}
return ops_with_fluid_kernel;
} else {
return op_names;
}
},
py::arg("lib") = "all",
R"DOC(
Return the operator names in paddle.
Args:
lib[string]: the library contains corresponding OpKernel, could be 'phi', 'fluid' and 'all'. Default value is 'all'.
)DOC");
m.def("get_op_attrs_default_value", m.def("get_op_attrs_default_value",
[](py::bytes byte_name) -> paddle::framework::AttributeMap { [](py::bytes byte_name) -> paddle::framework::AttributeMap {
std::string op_type = byte_name; std::string op_type = byte_name;
......
...@@ -39,5 +39,19 @@ class TestGetAllRegisteredOpKernels(unittest.TestCase): ...@@ -39,5 +39,19 @@ class TestGetAllRegisteredOpKernels(unittest.TestCase):
self.assertTrue(core._get_all_register_op_kernels()['sign']) self.assertTrue(core._get_all_register_op_kernels()['sign'])
class TestGetAllOpNames(unittest.TestCase):
def test_get_all_op_names(self):
all_op_names = core.get_all_op_names()
all_op_with_phi_kernels = core.get_all_op_names("phi")
all_op_with_fluid_kernels = core.get_all_op_names("fluid")
self.assertTrue(
len(all_op_names) > len(
set(all_op_with_phi_kernels) | set(all_op_with_fluid_kernels)))
self.assertTrue("scale" in all_op_with_phi_kernels)
self.assertTrue("scale" in all_op_with_phi_kernels)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册