提交 62f7dc49 编写于 作者: B buxue

support convert bool scalar and tensor to number tensor

上级 32a72c19
......@@ -255,9 +255,6 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
continue;
}
if ((arg_type_id == kNumberTypeBool || it->second == kNumberTypeBool) && arg_type_id != it->second) {
continue;
}
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
}
}
......
......@@ -101,10 +101,8 @@ def test_pow():
result = testpow(input_tensor, power)
assert np.all(result.asnumpy() == expect)
net = PowNet()
with pytest.raises(TypeError):
net(input_tensor, True)
with pytest.raises(TypeError):
net(input_tensor, power2)
net(input_tensor, True)
net(input_tensor, power2)
def test_exp():
......
......@@ -293,13 +293,6 @@ raise_set = [
'desc_inputs': [5.0],
'skip': ['backward']}),
# input x is Tensor(bool)
('Pow1', {
'block': (P.Pow(),
{'exception': TypeError, 'error_keywords': ['Pow']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0],
'skip': ['backward']}),
# input is not Tensor
('Exp1', {
'block': (P.Exp(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册