提交 5381f945 编写于 作者: M Megvii Engine Team

feat(imperative): add error message when the operand of elemwise is None

GitOrigin-RevId: 4235e58fc29760e9534a63739f5610b0f6b22a57
上级 ffe7ceb3
...@@ -101,7 +101,10 @@ PyObject* py_apply( ...@@ -101,7 +101,10 @@ PyObject* py_apply(
HostTensorND ht(target_cn); HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
record_py_backtrace(); record_py_backtrace();
if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler //! operand in elemwise can't be None
if (args[i] == Py_None) {
throw py::type_error("the operand is None and is not supported.");
} else if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
// py_tuple is not allowed here because of tracing // py_tuple is not allowed here because of tracing
return imperative::apply( return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()), CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import copy import copy
import unittest
import numpy as np import numpy as np
import pytest import pytest
...@@ -235,3 +236,11 @@ def test_tensor_construct_tensor(): ...@@ -235,3 +236,11 @@ def test_tensor_construct_tensor():
assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1" assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1"
assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device
_full_sync() _full_sync()
class TestElemwiseNone(unittest.TestCase):
def test_elemementwise_and_with_none(self):
with self.assertRaises(TypeError) as context:
a = Tensor(1.0)
b = a + None
assert str(context.exception) == "the operand is None and is not supported."
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册