diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py index 84153b58a165b9e9ee903badb1ebe92ce88fc4af..1906a985b7967b8ad55a4fb34764dc53d701b465 100644 --- a/imperative/python/test/unit/core/test_function.py +++ b/imperative/python/test/unit/core/test_function.py @@ -8,6 +8,7 @@ import copy import numpy as np +import pytest import megengine.autodiff as ad import megengine.functional as F @@ -303,3 +304,17 @@ def test_zero_grad(): np.testing.assert_almost_equal( net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32), ) + + +def test_throw_on_non_tensor_argument(): + class NonTensorArg(Function): + def forward(self, inp, c): + return inp + c + + def backward(self, grad): + return grad + + x = tensor(np.array([2.33], dtype=np.float32)) + func = NonTensorArg() + with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): + func(x, 1) diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h index 93e5c397250721cede911cd637a4e989751d5428..850f23466f2bfb460f6c453861002d3d077aea82 100644 --- a/imperative/src/impl/op_trait.h +++ b/imperative/src/impl/op_trait.h @@ -108,7 +108,7 @@ struct OpMethNotImpl { struct OpMethFallback : public OpMethNotImpl { using OpMethNotImpl::impl; static void impl(ApplyOnPhysicalTensor& func, - op_meth_tag::ApplyOnPhysicalTensor); + op_meth_tag::ApplyOnPhysicalTensor); static void impl(Execute& func, op_meth_tag::Execute); static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); static void impl(InferOutputAttrsFallible& func, @@ -120,9 +120,9 @@ struct OpMethFallback : public OpMethNotImpl { template struct OpMeth : public thin_function { using Base = thin_function; - using Base::operator bool; OpMeth() : Base{}, allow_fallback(false){}; explicit OpMeth(const Base& base) { this->Base::operator=(base); } + using Base::operator bool; RType operator()(Args... args) const { if (!this->Base::operator bool()) { if (allow_fallback) {