未验证 提交 8d4f2613 编写于 作者: C Chen Weihang 提交者: GitHub

add op count by lib method (#45680)

上级 7a92e74b
......@@ -1049,13 +1049,44 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def("get_all_op_names", []() {
std::vector<std::string> op_names;
for (auto &iter : OpInfoMap::Instance().map()) {
op_names.emplace_back(iter.first);
}
return op_names;
});
m.def(
"get_all_op_names",
[](const std::string &lib) {
std::vector<std::string> op_names;
for (auto &iter : OpInfoMap::Instance().map()) {
op_names.emplace_back(iter.first);
}
if (lib == "phi") {
std::vector<std::string> ops_with_phi_kernel;
for (const auto &op_name : op_names) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_name)) {
ops_with_phi_kernel.emplace_back(op_name);
}
}
return ops_with_phi_kernel;
} else if (lib == "fluid") {
std::vector<std::string> ops_with_fluid_kernel;
auto all_fluid_op_kernels =
paddle::framework::OperatorWithKernel::AllOpKernels();
for (const auto &op_name : op_names) {
if (all_fluid_op_kernels.find(op_name) !=
all_fluid_op_kernels.end()) {
ops_with_fluid_kernel.emplace_back(op_name);
}
}
return ops_with_fluid_kernel;
} else {
return op_names;
}
},
py::arg("lib") = "all",
R"DOC(
Return the operator names in paddle.
Args:
lib[string]: the library contains corresponding OpKernel, could be 'phi', 'fluid' and 'all'. Default value is 'all'.
)DOC");
m.def("get_op_attrs_default_value",
[](py::bytes byte_name) -> paddle::framework::AttributeMap {
std::string op_type = byte_name;
......
......@@ -39,5 +39,19 @@ class TestGetAllRegisteredOpKernels(unittest.TestCase):
self.assertTrue(core._get_all_register_op_kernels()['sign'])
class TestGetAllOpNames(unittest.TestCase):
def test_get_all_op_names(self):
all_op_names = core.get_all_op_names()
all_op_with_phi_kernels = core.get_all_op_names("phi")
all_op_with_fluid_kernels = core.get_all_op_names("fluid")
self.assertTrue(
len(all_op_names) > len(
set(all_op_with_phi_kernels) | set(all_op_with_fluid_kernels)))
self.assertTrue("scale" in all_op_with_phi_kernels)
self.assertTrue("scale" in all_op_with_phi_kernels)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册