From 2950dd8d6981412663aa8bf670e815673bffd459 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 10 Aug 2021 14:25:11 +0800 Subject: [PATCH] test(imperative): test autodiff.Function with non tensor arguments GitOrigin-RevId: 6114f48d2188829ff64c9560003e020eb763a13d --- imperative/python/test/unit/core/test_function.py | 15 +++++++++++++++ imperative/src/impl/op_trait.h | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py index 84153b58a..1906a985b 100644 --- a/imperative/python/test/unit/core/test_function.py +++ b/imperative/python/test/unit/core/test_function.py @@ -8,6 +8,7 @@ import copy import numpy as np +import pytest import megengine.autodiff as ad import megengine.functional as F @@ -303,3 +304,17 @@ def test_zero_grad(): np.testing.assert_almost_equal( net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32), ) + + +def test_throw_on_non_tensor_argument(): + class NonTensorArg(Function): + def forward(self, inp, c): + return inp + c + + def backward(self, grad): + return grad + + x = tensor(np.array([2.33], dtype=np.float32)) + func = NonTensorArg() + with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): + func(x, 1) diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h index 93e5c3972..850f23466 100644 --- a/imperative/src/impl/op_trait.h +++ b/imperative/src/impl/op_trait.h @@ -108,7 +108,7 @@ struct OpMethNotImpl { struct OpMethFallback : public OpMethNotImpl { using OpMethNotImpl::impl; static void impl(ApplyOnPhysicalTensor& func, - op_meth_tag::ApplyOnPhysicalTensor); + op_meth_tag::ApplyOnPhysicalTensor); static void impl(Execute& func, op_meth_tag::Execute); static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); static void impl(InferOutputAttrsFallible& func, @@ -120,9 +120,9 @@ struct OpMethFallback : public OpMethNotImpl { template struct OpMeth : public thin_function { using Base = thin_function; - using Base::operator bool; OpMeth() : Base{}, allow_fallback(false){}; explicit OpMeth(const Base& base) { this->Base::operator=(base); } + using Base::operator bool; RType operator()(Args... args) const { if (!this->Base::operator bool()) { if (allow_fallback) { -- GitLab