未验证 提交 ee56906e 编写于 作者: L Leo Chen 提交者: GitHub

fit for printing cinn_launch op (#42141)

* fit for printing cinn_launch op

* update boost::variant caster for bytes
上级 18e9aafb
......@@ -1930,6 +1930,17 @@ All parameter, weight, gradient are variables in Paddle.
which contains the id pair of pruned block and corresponding
origin block.
)DOC");
m.def("get_readable_comile_key", [](const OpDesc &op_desc) {
auto compilation_key =
BOOST_GET_CONST(std::string, op_desc.GetAttr("compilation_key"));
VLOG(4) << std::hash<std::string>{}(compilation_key) << " "
<< compilation_key.size();
proto::ProgramDesc desc;
desc.ParseFromString(compilation_key);
auto s = desc.DebugString();
VLOG(4) << s;
return s;
});
m.def("empty_var_name",
[]() { return std::string(framework::kEmptyVarName); });
m.def("grad_var_suffix",
......
......@@ -45,10 +45,28 @@ struct PYBIND11_HIDDEN paddle_variant_caster_visitor
paddle_variant_caster_visitor(return_value_policy policy, handle parent)
: policy(policy), parent(parent) {}
template <class T>
handle operator()(T const &src) const {
template <class T,
typename std::enable_if<!std::is_same<T, std::string>::value,
bool>::type* = nullptr>
handle operator()(T const& src) const {
return make_caster<T>::cast(src, policy, parent);
}
template <class T,
typename std::enable_if<std::is_same<T, std::string>::value,
bool>::type* = nullptr>
handle operator()(T const& src) const {
try {
return make_caster<T>::cast(src, policy, parent);
} catch (std::exception& ex) {
VLOG(4) << ex.what();
VLOG(4) << src;
// UnicodeDecodeError, src is not utf-8 encoded
// see details:
// https://github.com/pybind/pybind11/blob/master/docs/advanced/cast/strings.rst
return PYBIND11_BYTES_FROM_STRING_AND_SIZE(src.data(), src.size());
}
}
};
template <class Variant>
......@@ -105,7 +123,7 @@ struct paddle_variant_caster<V<Ts...>> {
return load_success_;
}
static handle cast(Type const &src, return_value_policy policy,
static handle cast(Type const& src, return_value_policy policy,
handle parent) {
paddle_variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
......
......@@ -2863,8 +2863,22 @@ class Operator(object):
attrs_str += ", "
continue
# it is bytes of serialized protobuf
if self.type == 'cinn_launch' and name == 'compilation_key':
# value = core.get_readable_comile_key(self.desc)
v = self.desc.attr(name)
prog = Program()
prog = prog.parse_from_string(v)
s = prog._to_readable_code()
lines = s.split('\n')
value = '\n'.join([' ' + line for line in lines])
value = '\n' + value
else:
value = self.desc.attr(name)
a = "{name} = {value}".format(
name=name, type=attr_type, value=self.desc.attr(name))
name=name, type=attr_type, value=value)
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册