提交 0ac47f2f 编写于 作者: J jiangjinsheng

fixed ScalarCast

上级 cbdc59e8
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
"""inner_ops""" """inner_ops"""
import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer from ..primitive import prim_attr_register, PrimitiveWithInfer
...@@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer): ...@@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer):
pass pass
def __infer__(self, x, t): def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
value, to = x['value'], t['value'] value, to = x['value'], t['value']
if value is not None: if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
if isinstance(to, type(tensor)): if isinstance(to, type(tensor)):
to = to.element_type() to = to.element_type()
np_type = dtype_to_pytype(to) np_type = dtype_to_pytype(to)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册