提交 e6b06914 编写于 作者: M Megvii Engine Team

feat(mge/tensor): support non-Tensor value in `_reset` and remove depreciated tests

GitOrigin-RevId: faf6c78aa8f6d7c43c95dc174261cc5c5d9edac1
上级 5a38ad39
......@@ -16,7 +16,6 @@ from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
......
......@@ -119,6 +119,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
return super().detach()
def _reset(self, other):
if not isinstance(other, _Tensor):
other = Tensor(other, dtype=self.dtype, device=self.device)
super()._reset(other)
def __repr__(self):
......@@ -141,8 +143,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value):
if not isinstance(value, _Tensor):
value = Tensor(value, dtype=self.dtype, device=self.device)
self._reset(value)
@deprecated(version="1.0", reason="use *= 0 instead")
......
......@@ -50,6 +50,14 @@ def test_reduce():
test_x(np.array([True, False, True]))
def test_set_value():
v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0)
v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
def test_set_subtensor():
x = Tensor([1, 2, 3])
x[:] = [1, 1, 1]
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor
from megengine.module import Conv2d
# TODO: delete this test after deleting set_value
def test_set_value():
v0 = np.random.random((2, 3)).astype(np.float32)
param = Parameter(v0)
v1 = np.random.random((2, 3)).astype(np.float32)
param.set_value(v1)
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
v2 = np.random.random((3, 3)).astype(np.float32)
# TODO: add this
# with pytest.raises(ValueError):
# param.set_value(v2)
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
@pytest.mark.skip(reason="fill unsupported")
def test_fill():
a = Tensor(np.zeros((2, 3), dtype=np.float32))
a.fill(3)
np.testing.assert_allclose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))
a.fill(124.568)
np.testing.assert_allclose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32))
# TODO: remove or rewrite following test
# def test_attach():
# p_ = np.random.random((2, 3)).astype(np.float32)
# with Graph() as g:
# g.set_option('eager_evaluation', False)
# p = Parameter(p_)
# v = p * 2
# f = compile(v, None)
# out, = f()
# np.testing.assert_allclose(out, p_ * 2)
# F.add_update(p, p)
# out, = f()
# np.testing.assert_allclose(out, p_ * 4)
# TODO: remove or rewrite following test
# def test_module_attach():
# v = np.random.random((1, 3, 64, 64)).astype(np.float32)
# net = Conv2d(3, 16, 3)
# with Graph() as g:
# g.set_option('eager_evaluation', False)
# data0 = Input("data")
# f = compile(net(data0), None)
# out0, = f(data=v)
# data1 = Input("data", value=v)
# out1 = net(data1)
# np.testing.assert_allclose(out0, out1.numpy())
# def test_shape_warning():
# with Graph() as cg:
# cg.set_option("eager_evaluation", False)
# b = Tensor(np.ones((2, 3)).astype(np.float32))
# with pytest.warns(None) as record:
# print(b.shape)
# if len(record) != 0:
# raise ValueError(
# "Getting the shape of a constant Tensor should throw no Warning"
# )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册