提交 d5ce6172 编写于 作者: W Wenchen Fan 提交者: Yin Huai

[SPARK-13740][SQL] add null check for _verify_type in types.py

## What changes were proposed in this pull request?

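This PR adds a null check to `_verify_type` that honors the schema's nullability information: a `None` value is now rejected with a `ValueError` when the struct field is not nullable, when an array's `containsNull` is false, when a map's `valueContainsNull` is false, or when it appears as a map key (map keys are never nullable).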

## How was this patch tested?

New doctests in `_verify_type` covering non-nullable scalars, array elements, map keys, and struct fields.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11574 from cloud-fan/py-null-check.
上级 9740954f
......@@ -1091,7 +1091,7 @@ _acceptable_types = {
}
def _verify_type(obj, dataType):
def _verify_type(obj, dataType, nullable=True):
"""
Verify the type of obj against dataType, raise a TypeError if they do not match.
......@@ -1120,10 +1120,29 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
>>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> _verify_type({None: 1}, MapType(StringType(), IntegerType()))
Traceback (most recent call last):
...
ValueError:...
>>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
>>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
"""
# all objects are nullable
if obj is None:
return
if nullable:
return
else:
raise ValueError("This field is not nullable, but got None")
# StringType can work with any types
if isinstance(dataType, StringType):
......@@ -1160,19 +1179,19 @@ def _verify_type(obj, dataType):
elif isinstance(dataType, ArrayType):
for i in obj:
_verify_type(i, dataType.elementType)
_verify_type(i, dataType.elementType, dataType.containsNull)
elif isinstance(dataType, MapType):
for k, v in obj.items():
_verify_type(k, dataType.keyType)
_verify_type(v, dataType.valueType)
_verify_type(k, dataType.keyType, False)
_verify_type(v, dataType.valueType, dataType.valueContainsNull)
elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
_verify_type(v, f.dataType, f.nullable)
# This is used to unpickle a Row from JVM
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册