diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index e926180780609c0a8ffc6270627835c50bbce782..59a64d71371b546f76eabdeed7e7514e8fb0f84a 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -87,11 +87,8 @@ class OpInfoMap { } } - template - void IterAllInfo(Callback callback) { - for (auto& it : map_) { - callback(it.first, it.second); - } + const std::unordered_map& map() const { + return map_; } private: diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 46c24e2cd53c068a25e1a5c8c6df600c3111e20a..d7cd738828a10b431370c92026b89d62add1275e 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -4,3 +4,5 @@ if(WITH_PYTHON) DEPS pybind python backward proto_desc tensor_array paddle_memory executor ${GLOB_OP_LIB}) endif(WITH_PYTHON) + +cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB} tensor_array) diff --git a/paddle/pybind/print_operators_doc.cc b/paddle/pybind/print_operators_doc.cc new file mode 100644 index 0000000000000000000000000000000000000000..24f2a9383f7a069f1a8c7ed2bf3da46720470efa --- /dev/null +++ b/paddle/pybind/print_operators_doc.cc @@ -0,0 +1,132 @@ +#include +#include // std::stringstream +#include + +#include "paddle/framework/op_info.h" +#include "paddle/framework/op_registry.h" +#include "paddle/pybind/pybind.h" + +std::string Escape(const std::string& s) { + std::string r; + for (size_t i = 0; i < s.size(); i++) { + switch (s[i]) { + case '\"': + r += "\\\""; + break; + case '\\': + r += "\\\\"; + break; + case '\n': + r += "\\n"; + break; + case '\t': + r += "\\t"; + case '\r': + break; + default: + r += s[i]; + break; + } + } + return r; +} + +std::string AttrType(paddle::framework::AttrType at) { + switch (at) { + case paddle::framework::INT: + return "int"; + case paddle::framework::FLOAT: + return "float"; + case paddle::framework::STRING: + return "string"; + case paddle::framework::BOOLEAN: + return "bool"; + case paddle::framework::INTS: + return "int array"; + case paddle::framework::FLOATS: + return "float array"; + case paddle::framework::STRINGS: + return "string array"; + case paddle::framework::BOOLEANS: + return "bool array"; + case paddle::framework::BLOCK: + return "block id"; + } + return "UNKNOWN"; // not possible +} + +void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) { + ss << " { " + << "\n" + << " \"name\" : \"" << Escape(v.name()) << "\",\n" + << " \"comment\" : \"" << Escape(v.comment()) << "\",\n" + << " \"duplicable\" : " << v.duplicable() << ",\n" + << " \"intermediate\" : " << v.intermediate() << "\n" + << " },"; +} + +void PrintAttr(const paddle::framework::OpProto::Attr& a, + std::stringstream& ss) { + ss << " { " + << "\n" + << " \"name\" : \"" << Escape(a.name()) << "\",\n" + << " \"type\" : \"" << AttrType(a.type()) << "\",\n" + << " \"comment\" : \"" << Escape(a.comment()) << "\",\n" + << " \"generated\" : " << a.generated() << "\n" + << " },"; +} + +void PrintOpProto(const std::string& type, + const paddle::framework::OpInfo& opinfo, + std::stringstream& ss) { + std::cerr << "Processing " << type << "\n"; + + const paddle::framework::OpProto* p = opinfo.proto_; + if (p == nullptr) { + return; // It is possible that an operator doesn't have OpProto. + } + + ss << "{\n" + << " \"type\" : \"" << Escape(p->type()) << "\",\n" + << " \"comment\" : \"" << Escape(p->comment()) << "\",\n"; + + ss << " \"inputs\" : [ " + << "\n"; + for (int i = 0; i < p->inputs_size(); i++) { + PrintVar(p->inputs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ], " + << "\n"; + + ss << " \"outputs\" : [ " + << "\n"; + for (int i = 0; i < p->outputs_size(); i++) { + PrintVar(p->outputs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ], " + << "\n"; + + ss << " \"attrs\" : [ " + << "\n"; + for (int i = 0; i < p->attrs_size(); i++) { + PrintAttr(p->attrs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ] " + << "\n"; + + ss << "},"; +} + +int main() { + std::stringstream ss; + ss << "[\n"; + for (auto& iter : paddle::framework::OpInfoMap::Instance().map()) { + PrintOpProto(iter.first, iter.second, ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << "]\n"; + std::cout << ss.str(); +} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 26b793a4bbf5df7a2635838a6c6a8264ca8ebb67..b6e44fdbad6e2817e3077901f58177adc4bb0c71 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -225,15 +225,16 @@ All parameter, weight, gradient are variables in Paddle. //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { std::vector ret_values; - - OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type, - const OpInfo &info) { - if (!info.HasOpProtoAndChecker()) return; - std::string str; - PADDLE_ENFORCE(info.Proto().SerializeToString(&str), - "Serialize OpProto Error. This could be a bug of Paddle."); - ret_values.emplace_back(str); - }); + for (auto &iter : OpInfoMap::Instance().map()) { + auto &info = iter.second; + if (info.HasOpProtoAndChecker()) { + std::string str; + PADDLE_ENFORCE( + info.Proto().SerializeToString(&str), + "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.emplace_back(str); + } + } return ret_values; }); m.def_submodule(