提交 4b3c98cc 编写于 作者: H huangdongrun

add back assignvalue

上级 044b2146
......@@ -272,7 +272,17 @@ bool Tensor::operator==(const Tensor &tensor) const {
bool Tensor::ValueEqual(const Tensor &tensor) const {
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
// assgin value to this tensor
Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);
dirty_ = tensor.is_dirty();
device_address_ = tensor.device_address();
data_ = tensor.data_;
id_ = tensor.id();
}
return *this;
}
abstract::AbstractBasePtr Tensor::ToAbstract() {
auto tens = shared_from_base<Tensor>();
auto dtype = tens->Dtype();
......
......@@ -147,6 +147,9 @@ class Tensor : public MetaTensor {
// it do real value comparison.
bool ValueEqual(const Tensor &tensor) const;
// assgin value to this tensor
Tensor &AssignValue(const Tensor &tensor);
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto &other_ = static_cast<const Tensor &>(other);
......
......@@ -327,6 +327,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.dim()
2
)mydelimiter")
.def("assign_value", &Tensor::AssignValue, R"mydelimiter(
Assign another tensor value to this.
Arg:
value (:class:`mindspore.tensor`): The value tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
>>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
>>> data.assign_value(data2)
>>> data.shape
(2, 2)
)mydelimiter")
.def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
Set the tensor's data type.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册