提交 f20edba9 编写于 作者: V Vadim Levin

fix: conditionally define generic NumPy NDArray alias

上级 fe4f5b53
......@@ -15,7 +15,8 @@ from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode
EnumerationNode, ConstantNode)
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode)
AggregatedTypeNode, ASTNodeTypeNode,
ConditionalAliasTypeNode, PrimitiveTypeNode)
def generate_typing_stubs(root: NamespaceNode, output_path: Path):
......@@ -682,28 +683,37 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
f"Provided type node '{type_node.ctype_name}' is not an aggregated type"
for item in filter(lambda i: isinstance(i, AliasRefTypeNode), type_node):
register_alias(PREDEFINED_TYPES[item.ctype_name]) # type: ignore
type_node = PREDEFINED_TYPES[item.ctype_name]
if isinstance(type_node, AliasTypeNode):
register_alias(type_node)
elif isinstance(type_node, ConditionalAliasTypeNode):
conditional_type_nodes[type_node.ctype_name] = type_node
def create_alias_for_enum_node(enum_node: ASTNode) -> AliasTypeNode:
"""Create int alias corresponding to the given enum node.
def create_alias_for_enum_node(enum_node_alias: AliasTypeNode) -> ConditionalAliasTypeNode:
"""Create conditional int alias corresponding to the given enum node.
Args:
enum_node (ASTNodeTypeNode): Enumeration node to create int alias for.
enum_node (AliasTypeNode): Enumeration node to create conditional
int alias for.
Returns:
AliasTypeNode: int alias node with same export name as enum.
ConditionalAliasTypeNode: conditional int alias node with same
export name as enum.
"""
enum_node = enum_node_alias.ast_node
assert enum_node.node_type == ASTNodeType.Enumeration, \
f"{enum_node} has wrong node type. Expected type: Enumeration."
enum_export_name, enum_module_name = get_enum_module_and_export_name(
enum_node
)
enum_full_export_name = f"{enum_module_name}.{enum_export_name}"
alias_node = AliasTypeNode.int_(enum_full_export_name,
enum_export_name)
type_checking_time_definitions.add(alias_node)
return alias_node
return ConditionalAliasTypeNode(
enum_export_name,
"typing.TYPE_CHECKING",
positive_branch_type=enum_node_alias,
negative_branch_type=PrimitiveTypeNode.int_(enum_export_name),
condition_required_imports=("import typing", )
)
def register_alias(alias_node: AliasTypeNode) -> None:
typename = alias_node.typename
......@@ -726,11 +736,15 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
continue
if item.ast_node.node_type != ASTNodeType.Enumeration:
continue
alias_node.value.items[i] = create_alias_for_enum_node(item.ast_node)
enum_node = create_alias_for_enum_node(item)
alias_node.value.items[i] = enum_node
conditional_type_nodes[enum_node.ctype_name] = enum_node
if isinstance(alias_node.value, ASTNodeTypeNode) \
and alias_node.value.ast_node == ASTNodeType.Enumeration:
alias_node.value = create_alias_for_enum_node(alias_node.ast_node)
enum_node = create_alias_for_enum_node(alias_node.ast_node)
conditional_type_nodes[enum_node.ctype_name] = enum_node
return
# Strip module prefix from aliased types
aliases[typename] = alias_node.value.full_typename.replace(
......@@ -744,7 +758,7 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
required_imports: Set[str] = set()
aliases: Dict[str, str] = {}
type_checking_time_definitions: Set[AliasTypeNode] = set()
conditional_type_nodes: Dict[str, ConditionalAliasTypeNode] = {}
# Resolve each node and register aliases
TypeNode.compatible_to_runtime_usage = True
......@@ -752,6 +766,12 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
node.resolve(root)
if isinstance(node, AliasTypeNode):
register_alias(node)
elif isinstance(node, ConditionalAliasTypeNode):
conditional_type_nodes[node.ctype_name] = node
for node in conditional_type_nodes.values():
for required_import in node.required_definition_imports:
required_imports.add(required_import)
output_stream = StringIO()
output_stream.write("__all__ = [\n")
......@@ -762,12 +782,10 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
_write_required_imports(required_imports, output_stream)
# Add type checking time definitions as generated __init__.py content
for alias in type_checking_time_definitions:
output_stream.write("if typing.TYPE_CHECKING:\n ")
output_stream.write(f"{alias.typename} = {alias.ctype_name}\nelse:\n")
output_stream.write(f" {alias.typename} = {alias.value.ctype_name}\n")
if type_checking_time_definitions:
output_stream.write("\n\n")
for _, type_node in conditional_type_nodes.items():
output_stream.write(f"if {type_node.condition}:\n ")
output_stream.write(f"{type_node.typename} = {type_node.positive_branch_type.full_typename}\nelse:\n")
output_stream.write(f" {type_node.typename} = {type_node.negative_branch_type.full_typename}\n\n\n")
for alias_name, alias_type in aliases.items():
output_stream.write(f"{alias_name} = {alias_type}\n")
......
......@@ -307,14 +307,31 @@ class AliasTypeNode(TypeNode):
return cls(ctype_name, PrimitiveTypeNode.float_(), export_name, doc)
@classmethod
def array_(cls, ctype_name: str, shape: Optional[Tuple[int, ...]],
dtype: Optional[str] = None, export_name: Optional[str] = None,
doc: Optional[str] = None):
def array_ref_(cls, ctype_name: str, array_ref_name: str,
shape: Optional[Tuple[int, ...]],
dtype: Optional[str] = None,
export_name: Optional[str] = None,
doc: Optional[str] = None):
"""Create alias to array reference alias `array_ref_name`.
This is required to preserve backward compatibility with Python < 3.9
and NumPy 1.20, when NumPy module introduces generics support.
Args:
ctype_name (str): Name of the alias.
array_ref_name (str): Name of the conditional array alias.
shape (Optional[Tuple[int, ...]]): Array shape.
dtype (Optional[str], optional): Array type. Defaults to None.
export_name (Optional[str], optional): Alias export name.
Defaults to None.
doc (Optional[str], optional): Documentation string for alias.
Defaults to None.
"""
if doc is None:
doc = "Shape: " + str(shape)
doc = f"NDArray(shape={shape}, dtype={dtype})"
else:
doc += ". Shape: " + str(shape)
return cls(ctype_name, NDArrayTypeNode(ctype_name, shape, dtype),
doc += f". NDArray(shape={shape}, dtype={dtype})"
return cls(ctype_name, AliasRefTypeNode(array_ref_name),
export_name, doc)
@classmethod
......@@ -376,23 +393,111 @@ class AliasTypeNode(TypeNode):
export_name, doc)
class ConditionalAliasTypeNode(TypeNode):
"""Type node representing an alias protected by condition checked in runtime.
Example:
```python
if numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" and sys.version_info >= (3, 9)
NumPyArray = numpy.ndarray[typing.Any, numpy.dtype[numpy.generic]]
else:
NumPyArray = numpy.ndarray
```
is defined as follows:
```python
ConditionalAliasTypeNode(
"NumPyArray",
'numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" and sys.version_info >= (3, 9)',
NDArrayTypeNode("NumPyArray"),
NDArrayTypeNode("NumPyArray", use_numpy_generics=False),
condition_required_imports=("import numpy", "import sys")
)
```
"""
def __init__(self, ctype_name: str, condition: str,
positive_branch_type: TypeNode,
negative_branch_type: TypeNode,
export_name: Optional[str] = None,
condition_required_imports: Sequence[str] = ()) -> None:
super().__init__(ctype_name)
self.condition = condition
self.positive_branch_type = positive_branch_type
self.positive_branch_type.ctype_name = self.ctype_name
self.negative_branch_type = negative_branch_type
self.negative_branch_type.ctype_name = self.ctype_name
self._export_name = export_name
self._condition_required_imports = condition_required_imports
@property
def typename(self) -> str:
if self._export_name is not None:
return self._export_name
return self.ctype_name
@property
def full_typename(self) -> str:
return "cv2.typing." + self.typename
@property
def required_definition_imports(self) -> Generator[str, None, None]:
yield from self.positive_branch_type.required_usage_imports
yield from self.negative_branch_type.required_usage_imports
yield from self._condition_required_imports
@property
def required_usage_imports(self) -> Generator[str, None, None]:
yield "import cv2.typing"
@property
def is_resolved(self) -> bool:
return self.positive_branch_type.is_resolved \
and self.negative_branch_type.is_resolved
def resolve(self, root: ASTNode):
try:
self.positive_branch_type.resolve(root)
self.negative_branch_type.resolve(root)
except TypeResolutionError as e:
raise TypeResolutionError(
'Failed to resolve alias "{}" exposed as "{}"'.format(
self.ctype_name, self.typename
)
) from e
@classmethod
def numpy_array_(cls, ctype_name: str, export_name: Optional[str] = None,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[str] = None):
return cls(
ctype_name,
('numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" '
'and sys.version_info >= (3, 9)'),
NDArrayTypeNode(ctype_name, shape, dtype),
NDArrayTypeNode(ctype_name, shape, dtype,
use_numpy_generics=False),
condition_required_imports=("import numpy", "import sys")
)
class NDArrayTypeNode(TypeNode):
"""Type node representing NumPy ndarray.
"""
def __init__(self, ctype_name: str, shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[str] = None) -> None:
def __init__(self, ctype_name: str,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[str] = None,
use_numpy_generics: bool = True) -> None:
super().__init__(ctype_name)
self.shape = shape
self.dtype = dtype
self._use_numpy_generics = use_numpy_generics
@property
def typename(self) -> str:
return "numpy.ndarray[{shape}, numpy.dtype[{dtype}]]".format(
if self._use_numpy_generics:
# NOTE: Shape is not fully supported yet
# shape=self.shape if self.shape is not None else "typing.Any",
shape="typing.Any",
dtype=self.dtype if self.dtype is not None else "numpy.generic"
)
dtype = self.dtype if self.dtype is not None else "numpy.generic"
return f"numpy.ndarray[typing.Any, numpy.dtype[{dtype}]]"
return "numpy.ndarray"
@property
def required_usage_imports(self) -> Generator[str, None, None]:
......
from .nodes.type_node import (
AliasTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
ASTNodeTypeNode, NDArrayTypeNode, NoneTypeNode, SequenceTypeNode,
TupleTypeNode, UnionTypeNode, AnyTypeNode
TupleTypeNode, UnionTypeNode, AnyTypeNode, ConditionalAliasTypeNode
)
# Set of predefined types used to cover cases when library doesn't
......@@ -30,12 +30,15 @@ _PREDEFINED_TYPES = (
PrimitiveTypeNode.str_("char"),
PrimitiveTypeNode.str_("String"),
PrimitiveTypeNode.str_("c_string"),
ConditionalAliasTypeNode.numpy_array_("NumPyArrayGeneric"),
ConditionalAliasTypeNode.numpy_array_("NumPyArrayFloat32", dtype="numpy.float32"),
ConditionalAliasTypeNode.numpy_array_("NumPyArrayFloat64", dtype="numpy.float64"),
NoneTypeNode("void"),
AliasTypeNode.int_("void*", "IntPointer", "Represents an arbitrary pointer"),
AliasTypeNode.union_(
"Mat",
items=(ASTNodeTypeNode("Mat", module_name="cv2.mat_wrapper"),
NDArrayTypeNode("Mat")),
AliasRefTypeNode("NumPyArrayGeneric")),
export_name="MatLike"
),
AliasTypeNode.sequence_("MatShape", PrimitiveTypeNode.int_()),
......@@ -137,10 +140,22 @@ _PREDEFINED_TYPES = (
ASTNodeTypeNode("gapi.wip.draw.Mosaic"),
ASTNodeTypeNode("gapi.wip.draw.Poly"))),
SequenceTypeNode("Prims", AliasRefTypeNode("Prim")),
AliasTypeNode.array_("Matx33f", (3, 3), "numpy.float32"),
AliasTypeNode.array_("Matx33d", (3, 3), "numpy.float64"),
AliasTypeNode.array_("Matx44f", (4, 4), "numpy.float32"),
AliasTypeNode.array_("Matx44d", (4, 4), "numpy.float64"),
AliasTypeNode.array_ref_("Matx33f",
array_ref_name="NumPyArrayFloat32",
shape=(3, 3),
dtype="numpy.float32"),
AliasTypeNode.array_ref_("Matx33d",
array_ref_name="NumPyArrayFloat64",
shape=(3, 3),
dtype="numpy.float64"),
AliasTypeNode.array_ref_("Matx44f",
array_ref_name="NumPyArrayFloat32",
shape=(4, 4),
dtype="numpy.float32"),
AliasTypeNode.array_ref_("Matx44d",
array_ref_name="NumPyArrayFloat64",
shape=(4, 4),
dtype="numpy.float64"),
NDArrayTypeNode("vector<uchar>", dtype="numpy.uint8"),
NDArrayTypeNode("vector_uchar", dtype="numpy.uint8"),
TupleTypeNode("GMat2", items=(ASTNodeTypeNode("GMat"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册