From 070c811732ec52ae8dba083206f730de614f1c0b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Jul 2021 10:15:13 +0800 Subject: [PATCH] fix(imperative): remove convert_inputs GitOrigin-RevId: a3c43db746789bb9ba418c91578c16368a96536c --- .../megengine/core/tensor/array_method.py | 10 +++--- .../python/megengine/functional/elemwise.py | 2 +- .../python/megengine/functional/math.py | 16 +++++---- imperative/python/megengine/functional/nn.py | 35 +++---------------- .../python/megengine/functional/tensor.py | 30 ++++++++-------- .../python/megengine/functional/vision.py | 13 ++++--- imperative/python/test/integration/test_bn.py | 11 +++--- .../python/test/integration/test_converge.py | 5 ++- .../test_converge_with_gradient_clip.py | 5 ++- .../test_converge_with_swap_and_drop.py | 5 ++- .../python/test/unit/core/test_interpreter.py | 2 +- .../test/unit/functional/test_elemwise.py | 13 ++++++- .../test/unit/functional/test_functional.py | 24 ++++++------- .../test/unit/optimizer/test_clip_grad.py | 5 +-- 14 files changed, 81 insertions(+), 95 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index aed4cb16d..1fbd4b877 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -13,7 +13,7 @@ from typing import Union import numpy as np from .._imperative_rt.common import CompNode -from .._imperative_rt.core2 import SymbolVar, Tensor, apply +from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion from ..ops import builtin from . import amp from .indexing import getitem, setitem @@ -81,7 +81,11 @@ def _matmul(inp1, inp2): inp1, inp2 = cast_tensors(inp1, inp2) else: compute_mode = "default" - inp1, inp2 = convert_inputs(inp1, inp2) + dtype = dtype_promotion(inp1, inp2) + if inp1.dtype != dtype: + inp1 = inp1.astype(dtype) + if inp2.dtype != dtype: + inp2 = inp2.astype(dtype) op = builtin.MatrixMul( transposeA=False, transposeB=False, compute_mode=compute_mode, format="default" ) @@ -91,7 +95,6 @@ def _matmul(inp1, inp2): def _transpose(data, axes): op = builtin.Dimshuffle(axes) - (data,) = convert_inputs(data) (result,) = apply(op, data) return result @@ -201,7 +204,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: def _reduce(mode): def f(self, axis=None, keepdims: bool = False): data = self - (data,) = convert_inputs(data) if mode == "mean": data = data.astype("float32") elif self.dtype == np.bool_: diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index a89ce790a..2b9bd6103 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -13,7 +13,7 @@ from ..core._imperative_rt.core2 import SymbolVar, apply from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor.array_method import _elwise -from ..core.tensor.utils import astype, convert_inputs +from ..core.tensor.utils import convert_inputs from ..tensor import Tensor from ..utils.deprecation import deprecated_func diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 05165bf5a..0488e257d 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -10,16 +10,16 @@ import collections import math from typing import Optional, Sequence, Tuple, Union -from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin from ..core.ops.special import Const from ..core.tensor import amp -from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar +from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar from ..tensor import Tensor from .debug_param import get_execution_strategy -from .elemwise import clip, exp, log, log1p -from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze +from .elemwise import clip +from .tensor import broadcast_to, concat, expand_dims, squeeze __all__ = [ "argmax", @@ -816,10 +816,13 @@ def matmul( compute_mode = "float32" inp1, inp2 = cast_tensors(inp1, inp2) else: - inp1, inp2 = convert_inputs(inp1, inp2) + dtype = dtype_promotion(inp1, inp2) + if inp1.dtype != dtype: + inp1 = inp1.astype(dtype) + if inp2.dtype != dtype: + inp2 = inp2.astype(dtype) remove_row, remove_col = False, False - dim1, dim2 = inp1.ndim, inp2.ndim # handle dim=1 cases, dot and matrix-vector multiplication if dim1 == 1 and dim2 == 1: @@ -931,7 +934,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: """ op = builtin.Dot() - inp1, inp2 = convert_inputs(inp1, inp2) assert ( inp1.ndim <= 1 and inp2.ndim <= 1 ), "Input tensors for dot must be 1-dimensional or scalar" diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index c6c00aad8..e506b03f7 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -10,8 +10,6 @@ from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt.core2 import apply -from ..core._imperative_rt.graph import VarNode -from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin from ..core.ops.builtin import BatchNorm, Elemwise from ..core.ops.special import Const @@ -21,7 +19,6 @@ from ..core.tensor.utils import ( astensor1d, astype, cast_tensors, - convert_inputs, convert_single_value, setscalar, ) @@ -33,18 +30,9 @@ from ..utils.deprecation import deprecated_func from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero from .debug_param import get_execution_strategy from .distributed import all_reduce_sum -from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum -from .math import argsort, matmul, max, prod, sum -from .tensor import ( - broadcast_to, - concat, - expand_dims, - full, - ones, - reshape, - squeeze, - zeros, -) +from .elemwise import _elwise, exp, log, log1p, maximum, minimum +from .math import matmul, max, sum +from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros __all__ = [ "adaptive_avg_pool2d", @@ -167,8 +155,6 @@ def conv1d( if amp._enabled: compute_mode = "float32" inp, weight, bias = cast_tensors(inp, weight, bias) - else: - inp, weight = convert_inputs(inp, weight) inp = expand_dims(inp, 3) weight = expand_dims(weight, 3) @@ -246,8 +232,6 @@ def conv2d( if amp._enabled: compute_mode = "float32" inp, weight, bias = cast_tensors(inp, weight, bias) - else: - inp, weight = convert_inputs(inp, weight) stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) @@ -304,7 +288,6 @@ def conv3d( :return: output tensor. """ assert conv_mode.lower() == "cross_correlation" - inp, weight = convert_inputs(inp, weight) D, H, W = 0, 1, 2 @@ -379,8 +362,6 @@ def conv_transpose2d( if amp._enabled: compute_mode = "float32" inp, weight, bias = cast_tensors(inp, weight, bias) - else: - inp, weight = convert_inputs(inp, weight) if groups != 1: raise NotImplementedError("group transposed conv2d is not supported yet.") @@ -454,7 +435,8 @@ def deformable_conv2d( compute_mode = "float32" inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias) else: - inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask) + offset = offset.astype("float32") + mask = mask.astype("float32") stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) @@ -493,7 +475,6 @@ def local_conv2d( conv_mode.lower() == "cross_correlation" or conv_mode.name == "CROSS_CORRELATION" ) - inp, weight = convert_inputs(inp, weight) stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) @@ -539,8 +520,6 @@ def conv_transpose3d( :param dilation: dilation of the 3D convolution operation. Default: 1 :return: output tensor. """ - inp, weight = convert_inputs(inp, weight) - D, H, W = 0, 1, 2 pad = _triple(padding) stride = _triple_nonzero(stride) @@ -1078,10 +1057,6 @@ def batch_norm( weight, bias, running_mean, running_var = cast_tensors( weight, bias, running_mean, running_var, promote=True ) - elif compute_mode != "float32": - inp, weight, bias, running_mean, running_var = convert_inputs( - inp, weight, bias, running_mean, running_var - ) weight = make_full_if_none(weight, 1) bias = make_full_if_none(bias, 0) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index f87dbb81d..ae07da0af 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -6,25 +6,18 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import math from typing import Iterable, Optional, Sequence, Union import numpy as np from ..core._imperative_rt import CompNode -from ..core._imperative_rt.core2 import SymbolVar, apply +from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion from ..core._wrap import as_device from ..core.ops import builtin from ..core.ops.builtin import Copy, Identity from ..core.ops.special import Const from ..core.tensor.array_method import _broadcast, _remove_axis -from ..core.tensor.utils import ( - astensor1d, - convert_inputs, - convert_single_value, - dtype_promotion, - get_device, -) +from ..core.tensor.utils import astensor1d, convert_inputs, get_device from ..device import get_default_device from ..tensor import Tensor from .elemwise import ceil, floor_div @@ -288,6 +281,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: if len(inps) == 1: return inps[0] + # FIXME: remove this convert_inputs inps = convert_inputs(*inps, device=device) if device is None: device = get_device(inps) @@ -640,6 +634,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: .. testcode:: + import numpy as np from megengine import tensor import megengine.functional as F mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool)) @@ -657,7 +652,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: [7. 4.]] """ - x, y = convert_inputs(x, y) if not isinstance(x, Tensor): raise TypeError("input x must be a tensor") if not isinstance(y, Tensor): @@ -669,6 +663,12 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: if x.device != mask.device: raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) + dtype = dtype_promotion(x, y) + if x.dtype != dtype: + x = x.astype(dtype) + if y.dtype != dtype: + y = y.astype(dtype) + v0, index0 = cond_take(mask, x) v1, index1 = cond_take(~mask, y) @@ -1021,12 +1021,10 @@ def arange( if stop is None: start, stop = 0, start - if isinstance(start, Tensor): - start = start.astype("float32") - if isinstance(stop, Tensor): - stop = stop.astype("float32") - if isinstance(step, Tensor): - step = step.astype("float32") + start = Tensor(start, dtype="float32") + stop = Tensor(stop, dtype="float32") + step = Tensor(step, dtype="float32") + num = ceil((stop - start) / step) stop = start + step * (num - 1) result = linspace(start, stop, num, device=device) diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index f1427cee9..14044df1f 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -8,6 +8,8 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Iterable, Optional, Tuple, Union +import numpy as np + from ..core._imperative_rt.core2 import apply from ..core.ops import builtin from ..core.tensor import megbrain_graph, utils @@ -98,7 +100,6 @@ def roi_pooling( output_shape = (output_shape, output_shape) op = builtin.ROIPooling(mode=mode, scale=scale) - inp, rois = utils.convert_inputs(inp, rois) result, _ = apply( op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) ) @@ -187,6 +188,8 @@ def roi_align( [0.1359 0.1359]]] """ + if inp.dtype != np.float32: + inp = inp.astype(np.float32) mode = mode.lower() assert mode in ["max", "average"], "only max/average mode is supported" if isinstance(output_shape, int): @@ -207,7 +210,6 @@ def roi_align( sample_height=sample_height, sample_width=sample_width, ) - inp, rois = utils.convert_inputs(inp, rois) result, *_ = apply(op, inp, rois) return result @@ -270,7 +272,7 @@ def nms( max_output = boxes.shape[0] op = builtin.NMSKeep(iou_thresh, max_output) - inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) + inp = (boxes.reshape(1, -1, 4),) indices, count = apply(op, *inp) indices = indices[0][: count[0]] keep_inds = sorted_idx[indices] @@ -442,10 +444,13 @@ def warp_perspective( [ 9. 10.]]]] """ + if inp.dtype == np.float32: + mat = mat.astype("float32") + if inp.dtype == np.float16: + inp = inp.astype("float32") op = builtin.WarpPerspective( imode=interp_mode, bmode=border_mode, format=format, border_val=border_val ) - inp, mat = utils.convert_inputs(inp, mat) out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device) if mat_idx is not None: mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device) diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py index 38816f2bd..a1cd96280 100644 --- a/imperative/python/test/integration/test_bn.py +++ b/imperative/python/test/integration/test_bn.py @@ -14,8 +14,7 @@ import megengine.autodiff as ad import megengine.distributed as dist import megengine.functional as F import megengine.optimizer as optimizer -from megengine import Parameter, tensor -from megengine.distributed.helper import get_device_count_by_fork +from megengine import tensor from megengine.jit import trace from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm @@ -88,7 +87,7 @@ def test_bn_no_track_stat(): optim = optimizer.SGD(m.parameters(), lr=1.0) optim.clear_grad() - data = np.random.random((6, nchannel, 2, 2)).astype("float32") + data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32")) with gm: loss = m(data).sum() gm.backward(loss) @@ -110,7 +109,7 @@ def test_bn_no_track_stat2(): optim = optimizer.SGD(m.parameters(), lr=1.0) optim.clear_grad() - data = np.random.random((6, nchannel, 2, 2)).astype("float32") + data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32")) with gm: loss = m(data).sum() gm.backward(loss) @@ -146,7 +145,7 @@ def test_trace_bn_forward_twice(): pred = net(inp) return pred - x = np.ones((1, 1, 32, 32), dtype=np.float32) + x = tensor(np.ones((1, 1, 32, 32), dtype=np.float32)) y = train_bn(x, net=Simple()) np.testing.assert_equal(y.numpy(), 0) @@ -194,5 +193,5 @@ def test_trace_several_syncbn(trace_mode): def test_frozen_bn_no_affine(): nchannel = 3 m = BatchNorm2d(nchannel, freeze=True, affine=False) - data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32")) + data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32")) m(data).numpy() diff --git a/imperative/python/test/integration/test_converge.py b/imperative/python/test/integration/test_converge.py index 35de394b7..08c0cb6d3 100644 --- a/imperative/python/test/integration/test_converge.py +++ b/imperative/python/test/integration/test_converge.py @@ -9,7 +9,6 @@ import itertools import numpy as np -import pytest import megengine as mge import megengine.autodiff as ad @@ -105,10 +104,10 @@ def test_training_converge(): xx, yy = np.meshgrid(x, x) xx = xx.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1)) - data = np.concatenate((xx, yy), axis=1).astype(np.float32) + data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32)) pred = infer(data).numpy() - precision = calculate_precision(data, pred) + precision = calculate_precision(data.numpy(), pred) assert precision == 1.0, "Test precision must be high enough, get {}".format( precision ) diff --git a/imperative/python/test/integration/test_converge_with_gradient_clip.py b/imperative/python/test/integration/test_converge_with_gradient_clip.py index 1c0f3ac05..fd6c642b9 100644 --- a/imperative/python/test/integration/test_converge_with_gradient_clip.py +++ b/imperative/python/test/integration/test_converge_with_gradient_clip.py @@ -9,7 +9,6 @@ import itertools import numpy as np -import pytest import megengine as mge import megengine.autodiff as ad @@ -110,10 +109,10 @@ def test_training_converge(): xx, yy = np.meshgrid(x, x) xx = xx.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1)) - data = np.concatenate((xx, yy), axis=1).astype(np.float32) + data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32)) pred = infer(data).numpy() - precision = calculate_precision(data, pred) + precision = calculate_precision(data.numpy(), pred) print("precision=", precision) assert precision == 1.0, "Test precision must be high enough, get {}".format( precision diff --git a/imperative/python/test/integration/test_converge_with_swap_and_drop.py b/imperative/python/test/integration/test_converge_with_swap_and_drop.py index 709301786..468b464af 100644 --- a/imperative/python/test/integration/test_converge_with_swap_and_drop.py +++ b/imperative/python/test/integration/test_converge_with_swap_and_drop.py @@ -9,7 +9,6 @@ import itertools import numpy as np -import pytest import megengine as mge import megengine.autodiff as ad @@ -118,10 +117,10 @@ def test_training_converge_with_swap_and_drop(): xx, yy = np.meshgrid(x, x) xx = xx.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1)) - data = np.concatenate((xx, yy), axis=1).astype(np.float32) + data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32)) pred = infer(Tensor(data)).numpy() - precision = calculate_precision(data, pred) + precision = calculate_precision(data.numpy(), pred) assert precision == 1.0, "Test precision must be high enough, get {}".format( precision ) diff --git a/imperative/python/test/unit/core/test_interpreter.py b/imperative/python/test/unit/core/test_interpreter.py index 548190135..07db2a2a6 100644 --- a/imperative/python/test/unit/core/test_interpreter.py +++ b/imperative/python/test/unit/core/test_interpreter.py @@ -36,7 +36,7 @@ def test_level1_infer_value(): def test_level1_infer_shape_with_unknown(): config_async_level(2) a = mge.tensor([[1, 2, 2, 3]], dtype="float32") - b = mge.tensor([1, 1]) + b = mge.tensor([1, 1], dtype="float32") multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32") c = F.matmul(b, multi2) # make DepType::SHAPE unknown diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 25dce2436..cc45dc5fe 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -13,7 +13,7 @@ import megengine.functional as F import megengine.functional.elemwise as elemwise from megengine import tensor from megengine.core.tensor import dtype -from megengine.functional.elemwise import Elemwise, _elwise +from megengine.functional.elemwise import Elemwise from megengine.jit import trace @@ -57,6 +57,17 @@ def test_multiply(): ) +def test_div(): + np.testing.assert_allclose( + F.div(tensor([3, 4]), 2).numpy(), + np.divide(np.array([3, 4], dtype=np.float32), 2), + ) + + np.testing.assert_allclose( + (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), + ) + + def test_clamp(): """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and `F.clip` will fall into wrong conditions unexpectedly. diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 977feb6bb..7bbdc6ad5 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -456,9 +456,10 @@ def test_interpolate_fastpath(): np.testing.assert_equal(out.item(), np_x.mean()) -def test_warp_perspective(): +@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16]) +def test_warp_perspective(dt): inp_shape = (1, 1, 4, 4) - x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + x = tensor(np.arange(16, dtype=dt).reshape(inp_shape)) M_shape = (1, 3, 3) # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) M = tensor( @@ -467,14 +468,13 @@ def test_warp_perspective(): ).reshape(M_shape) ) outp = F.vision.warp_perspective(x, M, (2, 2)) - np.testing.assert_equal( - outp.numpy(), np.array([[[[5.0, 6.0], [9.0, 10.0]]]], dtype=np.float32) - ) + np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt)) -def test_warp_perspective_mat_idx(): +@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16]) +def test_warp_perspective_mat_idx(dt): inp_shape = (2, 1, 4, 4) - x = tensor(np.arange(32, dtype=np.float32).reshape(inp_shape)) + x = tensor(np.arange(32, dtype=dt).reshape(inp_shape)) M_shape = (1, 3, 3) # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) M = tensor( @@ -488,12 +488,12 @@ def test_warp_perspective_mat_idx(): outp.numpy(), np.array( [ - [[[5.0, 6.0], [9.0, 10.0]]], - [[[21.0, 22.0], [25.0, 26.0]]], - [[[21.0, 22.0], [25.0, 26.0]]], - [[[5.0, 6.0], [9.0, 10.0]]], + [[[5, 6], [9, 10]]], + [[[21, 22], [25, 26]]], + [[[21, 22], [25, 26]]], + [[[5, 6], [9, 10]]], ], - dtype=np.float32, + dtype=dt, ), ) diff --git a/imperative/python/test/unit/optimizer/test_clip_grad.py b/imperative/python/test/unit/optimizer/test_clip_grad.py index 63dca8323..9864dfae8 100644 --- a/imperative/python/test/unit/optimizer/test_clip_grad.py +++ b/imperative/python/test/unit/optimizer/test_clip_grad.py @@ -5,11 +5,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import platform -import weakref import numpy as np -import pytest import megengine as mge import megengine.autodiff as ad @@ -65,7 +62,7 @@ def test_clip_grad_value(): gm = ad.GradManager().attach(net.parameters()) opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) with gm: - y = net(x) + y = net(mge.tensor(x)) y = y.mean() gm.backward(y) save_grad_value(net) -- GitLab