提交 90cedd3c 编写于 作者: M Megvii Engine Team

fix(imperative): restrict using convert_inputs in py_apply

GitOrigin-RevId: b021aac8a6f35dfe3b87dbc98a8007fd1a5b54b2
上级 7a63f1cd
...@@ -169,7 +169,7 @@ PyObject* py_apply( ...@@ -169,7 +169,7 @@ PyObject* py_apply(
} }
HostTensorND ht(target_cn); HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
if (PyArray_Check(args[i])) { // non scaler if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
return imperative::apply( return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()), CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
HostStorage::make(ht.storage()))[0]; HostStorage::make(ht.storage()))[0];
...@@ -205,8 +205,13 @@ PyObject* py_apply( ...@@ -205,8 +205,13 @@ PyObject* py_apply(
for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
tensors[i] = tw->m_tensor->data(); tensors[i] = tw->m_tensor->data();
} else { } else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i); tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
} }
} }
......
...@@ -77,6 +77,11 @@ def test_div(): ...@@ -77,6 +77,11 @@ def test_div():
np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), np.floor_divide(np.array([-5, -7], dtype=np.int32), 2),
) )
np.testing.assert_allclose(
(tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(),
np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]),
)
def test_clamp(): def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册