diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc index c43f5423f1b3ba55310848af6a4e41a61864562b..673a8da8423ec46108613535b01bbaa3633e6d95 100644 --- a/mindspore/ccsrc/ir/tensor.cc +++ b/mindspore/ccsrc/ir/tensor.cc @@ -272,7 +272,17 @@ bool Tensor::operator==(const Tensor &tensor) const { bool Tensor::ValueEqual(const Tensor &tensor) const { return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); } - +// assgin value to this tensor +Tensor &Tensor::AssignValue(const Tensor &tensor) { + if (this != &tensor) { + MetaTensor::operator=(tensor); + dirty_ = tensor.is_dirty(); + device_address_ = tensor.device_address(); + data_ = tensor.data_; + id_ = tensor.id(); + } + return *this; +} abstract::AbstractBasePtr Tensor::ToAbstract() { auto tens = shared_from_base(); auto dtype = tens->Dtype(); diff --git a/mindspore/ccsrc/ir/tensor.h b/mindspore/ccsrc/ir/tensor.h index cd308271294523d05494adad6c782fa1a3159c01..5be8a063c1119ed0e80f899a222b9e58f877e95e 100644 --- a/mindspore/ccsrc/ir/tensor.h +++ b/mindspore/ccsrc/ir/tensor.h @@ -147,6 +147,9 @@ class Tensor : public MetaTensor { // it do real value comparison. bool ValueEqual(const Tensor &tensor) const; + // assgin value to this tensor + Tensor &AssignValue(const Tensor &tensor); + bool operator==(const Value &other) const override { if (other.isa()) { auto &other_ = static_cast(other); diff --git a/mindspore/ccsrc/ir/tensor_py.cc b/mindspore/ccsrc/ir/tensor_py.cc index 1c763d48f462f31438f4a54df77dda56279c1786..11a000cef7d3c83096da051f5663f08001233167 100644 --- a/mindspore/ccsrc/ir/tensor_py.cc +++ b/mindspore/ccsrc/ir/tensor_py.cc @@ -327,6 +327,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { >>> data.dim() 2 )mydelimiter") + .def("assign_value", &Tensor::AssignValue, R"mydelimiter( + Assign another tensor value to this. + + Arg: + value (:class:`mindspore.tensor`): The value tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) + >>> data.assign_value(data2) + >>> data.shape + (2, 2) + )mydelimiter") .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( Set the tensor's data type.