提交 0ac47f2f 编写于 作者: J jiangjinsheng

fixed ScalarCast

上级 cbdc59e8
......@@ -15,6 +15,9 @@
"""inner_ops"""
import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer
......@@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer):
pass
def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
if isinstance(to, type(tensor)):
to = to.element_type()
np_type = dtype_to_pytype(to)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册