提交 2950dd8d 编写于 作者: M Megvii Engine Team

test(imperative): test autodiff.Function with non tensor arguments

GitOrigin-RevId: 6114f48d2188829ff64c9560003e020eb763a13d
上级 20e8541b
......@@ -8,6 +8,7 @@
import copy
import numpy as np
import pytest
import megengine.autodiff as ad
import megengine.functional as F
......@@ -303,3 +304,17 @@ def test_zero_grad():
np.testing.assert_almost_equal(
net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
)
def test_throw_on_non_tensor_argument():
class NonTensorArg(Function):
def forward(self, inp, c):
return inp + c
def backward(self, grad):
return grad
x = tensor(np.array([2.33], dtype=np.float32))
func = NonTensorArg()
with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"):
func(x, 1)
......@@ -108,7 +108,7 @@ struct OpMethNotImpl {
struct OpMethFallback : public OpMethNotImpl {
using OpMethNotImpl::impl;
static void impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor);
op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(InferOutputAttrsFallible& func,
......@@ -120,9 +120,9 @@ struct OpMethFallback : public OpMethNotImpl {
template <typename Tag, typename RType, typename... Args>
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::operator bool;
OpMeth() : Base{}, allow_fallback(false){};
explicit OpMeth(const Base& base) { this->Base::operator=(base); }
using Base::operator bool;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
if (allow_fallback) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册