未验证 提交 c09738e3 编写于 作者: M Mihai Maruseac 提交者: GitHub

Merge pull request #58564 from tensorflow/r2.10-9f03a9d3

r2.10 cherry-pick: 9f03a9d3 "Replace CHECK with returning an InternalError on failing to create python tuple"
......@@ -83,8 +83,8 @@ bool IsCPUDevice(const Device* d) {
return d == nullptr || d->tensorflow_accelerator_device_info() == nullptr;
}
// Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
// Given the 'call', prepares the token and inputs as a python tuple that is
// appropriate for calling the trampoline.
Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
int64_t n = call->ins.size();
PyObject* lst = PyList_New(n);
......@@ -119,7 +119,11 @@ Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
PyList_SetItem(lst, i, arg);
}
*tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst);
CHECK(*tuple);
if (*tuple == nullptr) {
return errors::Internal(
"Failed to create python tuple. Please make sure `token` is a "
"well-formed UTF-8 string.");
}
return OkStatus();
}
......
......@@ -17,7 +17,9 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.script_ops import numpy_function
......@@ -103,6 +105,15 @@ class PyFunctionTest(test.TestCase):
expect_result = constant_op.constant(3, dtypes.int32)
self.assertAllEqual(actual_result, expect_result)
@test_util.run_in_graph_and_eager_modes
def test_fail_on_non_utf8_token(self):
value = constant_op.constant(value=[1, 2])
token = b"\xb0"
data_type = [dtypes.int32]
with self.assertRaises((errors.InternalError, UnicodeDecodeError)):
self.evaluate(
gen_script_ops.py_func(input=[value], token=token, Tout=data_type))
if __name__ == "__main__":
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册