提交 b27022dd 编写于 作者: M Manjunath Kudlur 提交者: TensorFlower Gardener

Fix GetOpList and GetPythonWrappers SWIG wrappers for Python 3.

- Return the result from GetOpList as uninterpreted bytes object.
- Write a input typemap for GetPythonWrappers to receive python 'bytes'
object and convert to const char* pointer and length.
Change: 117258253
上级 725e968a
......@@ -237,11 +237,7 @@ tensorflow::ImportNumpy();
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
// the terminator.
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
%#if PY_MAJOR_VERSION < 3
$result = PyString_FromStringAndSize(
%#else
$result = PyUnicode_FromStringAndSize(
%#endif
$result = PyBytes_FromStringAndSize(
reinterpret_cast<const char*>($1.data), $1.length);
}
......
......@@ -43,8 +43,6 @@ def load_op_library(library_filename):
Pass "library_filename" to a platform-specific mechanism for dynamically
loading a library. The rules for determining the exact location of the
library are platform-specific and are not documented here.
Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
defined in the library.
Args:
library_filename: Path to the plugin.
......@@ -78,7 +76,7 @@ def load_op_library(library_filename):
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))
wrappers = py_tf.GetPythonWrappers(op_list_str)
# Get a unique name for the module.
module_name = hashlib.md5(wrappers).hexdigest()
......
......@@ -693,8 +693,8 @@ string GetAllPythonOps(const char* hidden, bool require_shapes) {
return GetPythonOps(ops, hidden, require_shapes);
}
string GetPythonWrappers(const char* buf, size_t len) {
string op_list_str(buf, len);
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len) {
string op_list_str(op_wrapper_buf, op_wrapper_len);
OpList ops;
ops.ParseFromString(op_list_str);
return GetPythonOps(ops, "", false);
......
......@@ -34,7 +34,7 @@ string GetPythonOps(const OpList& ops, const string& hidden_ops,
// Get the python wrappers for a list of ops in a OpList.
// buf should be a pointer to a buffer containing the binary encoded OpList
// proto, and len should be the length of that buffer.
string GetPythonWrappers(const char* buf, size_t len);
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len);
} // namespace tensorflow
......
......@@ -19,6 +19,23 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen.h"
%}
// Input typemap for GetPythonWrappers.
// Accepts a python object of 'bytes' type, and converts it to
// a const char* pointer and size_t length. The default typemap
// going from python bytes to const char* tries to decode the
// contents from utf-8 to unicode for Python version >= 3, but
// we want the bytes to be uninterpreted.
%typemap(in) (const char* op_wrapper_buf, size_t op_wrapper_len) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
SWIG_fail;
}
$1 = c_string;
$2 = static_cast<size_t>(py_size);
}
%ignoreall;
%unignore tensorflow::GetPythonWrappers;
%include "tensorflow/python/framework/python_op_gen.h"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册