提交 2484cd27 编写于 作者: M Megvii Engine Team

fix(tensor): check args when construct tensor with existing tensor

GitOrigin-RevId: 03454540707f42d409fdfdf88b5c044c56cf43b5
上级 e7587617
...@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is " "optimizer can only optimize Parameters, but one of the params is "
+ str(type(param)) + str(type(param))
) )
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) param[...] = Tensor(param.numpy(), no_cache=True)
for name, default in self._defaults.items(): for name, default in self._defaults.items():
if default is required and name not in param_group: if default is required and name not in param_group:
......
...@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
} }
mgb_assert(tup.size() == 7); mgb_assert(tup.size() == 7);
if (auto* t = try_cast(tup[0].ptr())) { if (auto* t = try_cast(tup[0].ptr())) {
m_tensor = t->m_tensor->copy(); m_tensor = t->m_tensor;
// TODO: merge two path in arg parse
if (!tup[1].is_none()) {
auto dtype = tup[1].cast<DType>();
mgb_assert(
dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
dtype.name(), m_tensor->dtype().name());
}
if (!tup[2].is_none()) {
auto device = as_comp_node(tup[2]);
mgb_assert(
device == m_tensor->comp_node(), "device mismatch: %s vs %s",
device.to_string().c_str(),
m_tensor->comp_node().to_string().c_str());
}
mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
bool no_cache = tup[4].cast<bool>();
if (no_cache) {
// always copy because it's hard to tell whether this tensor is cached
m_tensor = m_tensor->copy();
}
// ignore name
if (!tup[6].is_none()) {
Format format = tup[6].cast<std::string>();
mgb_assert(
format == m_tensor->format(), "format mismatch: %s vs %s",
format.to_string().c_str(), m_tensor->format().to_string().c_str());
}
} else { } else {
auto data = tup[0]; auto data = tup[0];
DType dtype = tup[1].cast<DType>(); DType dtype = tup[1].cast<DType>();
...@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) { ...@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) {
try { try {
self.compiled->compile(); self.compiled->compile();
} catch (const std::exception& e) { } catch (const std::exception& e) {
mgb_log_error(e.what()); mgb_log_error("error in trace: %s", e.what());
} }
} }
// register transformations // register transformations
......
...@@ -47,7 +47,7 @@ public: ...@@ -47,7 +47,7 @@ public:
~Tensor() = default; ~Tensor() = default;
inline Tensor copy() { return *this; } inline Tensor copy() { return Tensor(imperative::apply(DupTensor(), data())[0]); }
inline DType dtype() { return *data().dtype(); } inline DType dtype() { return *data().dtype(); }
inline CompNode comp_node() { return *data().device(); } inline CompNode comp_node() { return *data().device(); }
......
...@@ -5,7 +5,9 @@ import numpy as np ...@@ -5,7 +5,9 @@ import numpy as np
import pytest import pytest
from utils import get_var_value, make_tensor from utils import get_var_value, make_tensor
from megengine import _full_sync
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.device import get_default_device
from megengine.tensor import Parameter, Tensor from megengine.tensor import Parameter, Tensor
from megengine.utils.network import Network from megengine.utils.network import Network
...@@ -220,3 +222,16 @@ def test_tensor_from_bool(): ...@@ -220,3 +222,16 @@ def test_tensor_from_bool():
assert x.dtype == np.bool_ assert x.dtype == np.bool_
x = Tensor([True, False]) x = Tensor([True, False])
assert x.dtype == np.bool_ assert x.dtype == np.bool_
def test_tensor_construct_tensor():
x = Tensor(0, dtype=np.float32, device="xpu0:1", name="MyName")
assert Tensor(x.astype(np.int32)).dtype == np.int32
with pytest.raises(RuntimeError):
Tensor(x.astype(np.int32), dtype=np.float32)
assert Tensor(x).name == ""
assert Tensor(x, name="MyName2").name == "MyName2"
with pytest.raises(RuntimeError):
assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1"
assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device
_full_sync()
...@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation( ...@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
} else { } else {
return {ValueRef()}; return {ValueRef()};
} }
} else if (op.is<DupTensor>()) {
auto& input = inputs[0].cast(m_value_type);
DeviceTensorND dev_tensor;
dev_tensor.copy_from(m_channel->get_dev_tensor(input.handle()->handle()));
return m_value_type.make(share_handle(m_channel->put(dev_tensor, {})));
} else { } else {
return op.fallback(inputs); return op.fallback(inputs);
} }
......
...@@ -196,5 +196,10 @@ public: ...@@ -196,5 +196,10 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
class DupTensor final : public OperatorImpl<DupTensor, Operator::IdentityLike> {
public:
std::string to_string() const override { return "DupTensor"; }
};
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册