提交 fabebb26 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5264 [bug]fix bugs when parameters updata

Merge pull request !5264 from vlne-v1/I1SP3I-return-value-not-the-exact-value
...@@ -751,19 +751,17 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { ...@@ -751,19 +751,17 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
return ExecDFGraph(info_, args, phase_s); return ExecDFGraph(info_, args, phase_s);
} }
#else #else
if (backend == "ms" || backend == "ge") { auto ret_val = std::make_shared<py::object>();
auto ret_val = std::make_shared<py::object>(); if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { return *ret_val;
return *ret_val;
}
} }
if (backend == "ge") { }
if (args.size() > 0) { if (backend == "ge") {
return args[0]; if (args.size() > 0) {
} return args[0];
return args;
} }
return args;
} }
#endif #endif
std::size_t full_arg_size = ArgListSize(phase_s); std::size_t full_arg_size = ArgListSize(phase_s);
......
...@@ -389,6 +389,8 @@ class Parameter(MetaTensor): ...@@ -389,6 +389,8 @@ class Parameter(MetaTensor):
raise RuntimeError("Must set or change parallel mode before any Initializer created.") raise RuntimeError("Must set or change parallel mode before any Initializer created.")
if self.init_mode is None: if self.init_mode is None:
return self return self
if self.inited_param is not None:
return self.inited_param
if layout is not None: if layout is not None:
if not isinstance(layout, list): if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}.".format(layout)) raise TypeError("The layout should be list! layout is {}.".format(layout))
......
...@@ -36,8 +36,8 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { ...@@ -36,8 +36,8 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
// if is parameter always no value. // if is parameter always no value.
if (is_parameter()) { if (is_parameter_) {
auto param_name = param_info()->name(); auto param_name = param_info_->name();
auto ref_key = std::make_shared<RefKey>(param_name); auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract(); auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
......
...@@ -476,8 +476,8 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { ...@@ -476,8 +476,8 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
auto tensor_shape = tens->shape(); auto tensor_shape = tens->shape();
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
// if is parameter always no value. // if is parameter always no value.
if (is_parameter()) { if (is_parameter_) {
auto param_name = param_info()->name(); auto param_name = param_info_->name();
auto ref_key = std::make_shared<RefKey>(param_name); auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract(); auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import numpy as np import numpy as np
import pytest import pytest
from mindspore import context, Tensor, Parameter, ParameterTuple from mindspore import context, Tensor, Parameter, ParameterTuple, nn
from mindspore._checkparam import _check_str_by_regular from mindspore._checkparam import _check_str_by_regular
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
...@@ -229,3 +229,25 @@ def test_parameter_lazy_init(): ...@@ -229,3 +229,25 @@ def test_parameter_lazy_init():
para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True) para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2))) assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
def test_parameter_as_output():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
initial_input = initializer('One', shape=(2,), dtype=mstype.int32)
updated_input = Tensor([2, 2], mstype.int32)
class Net(nn.Cell):
def __init__(self, initial, updated):
super().__init__()
self.initial = initial
self.updated = updated
self.p = Parameter(self.initial, name="weight")
self.new_p = self.p.init_data()
self.new_p.set_parameter_data(self.updated)
def construct(self):
return self.new_p
net = Net(initial_input, updated_input)
output = net()
assert np.array_equal(output.asnumpy(), np.array([2, 2], np.int32))
context.reset_auto_parallel_context()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册