Commit d5255fe3 authored by zhousiyi

fix assign used in while

Parent bc30576a
@@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
       auto weight_name = weight_node->cast<ParameterPtr>()->name();
       // find the fakequant from input
       int count = 0;
-      int max_depth = 5;
+      const int max_depth = 5;
       while (!is_quant_cnode(x)) {
         if (count >= max_depth) {
           break;
@@ -429,6 +429,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
+  if (*this == *other) {
+    if (sparse_grad() == other->sparse_grad()) {
+      return shared_from_base<AbstractBase>();
+    }
+  }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
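The added early return means that joining two already-equal tensor abstracts (which also agree on sparse_grad) hands back the existing object instead of allocating a fresh AbstractTensor, so repeated joins during loop analysis can reach a fixed point. A minimal Python sketch of that idea, using toy stand-in classes rather than MindSpore's real ones:

```python
# Toy model of abstract-value joining, NOT MindSpore's implementation.
class AbstractTensor:
    def __init__(self, dtype, shape, sparse_grad=False):
        self.dtype = dtype              # element type
        self.shape = shape              # tuple, or None meaning "any shape"
        self.sparse_grad = sparse_grad

    def __eq__(self, other):
        return (isinstance(other, AbstractTensor)
                and self.dtype == other.dtype
                and self.shape == other.shape)

    def join(self, other):
        if not isinstance(other, AbstractTensor):
            raise TypeError("Join failed as type mismatch")
        # The fast path added above: identical abstracts join to self,
        # so re-joining the same loop-carried value changes nothing.
        if self == other and self.sparse_grad == other.sparse_grad:
            return self
        shape = self.shape if self.shape == other.shape else None
        return AbstractTensor(self.dtype, shape, self.sparse_grad)

a = AbstractTensor("float32", (1024, 512))
b = AbstractTensor("float32", (1024, 512))
assert a.join(b) is a   # same object back: the analysis can terminate
```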
@@ -812,6 +817,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
   return false;
 }
 
+AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
+  auto other_ref = other->cast<AbstractRefPtr>();
+  if (other_ref == nullptr) {
+    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
+  }
+  if (*this == *other) {
+    return shared_from_base<AbstractBase>();
+  }
+  auto ref_key = ref_key_->Join(other_ref->ref_key_);
+  auto ref = ref_->Join(other_ref->ref());
+  auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
+  return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
+}
+
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
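This is the new Ref-aware join: equal Refs short-circuit to self, and otherwise the three components (ref_key, ref, ref_origin) are joined pairwise. A self-contained toy sketch of the same shape; the classes below are illustrative, not MindSpore's API:

```python
class Abstract:
    """Toy component: equal values join to self, otherwise widen to unknown."""
    def __init__(self, value):
        self.value = value
    def __eq__(self, other):
        return isinstance(other, Abstract) and self.value == other.value
    def join(self, other):
        return self if self == other else Abstract(None)

class AbstractRef:
    def __init__(self, ref_key, ref, ref_origin):
        self.ref_key, self.ref, self.ref_origin = ref_key, ref, ref_origin
    def __eq__(self, other):
        return (isinstance(other, AbstractRef)
                and self.ref_key == other.ref_key
                and self.ref == other.ref
                and self.ref_origin == other.ref_origin)
    def join(self, other):
        if not isinstance(other, AbstractRef):
            raise TypeError("Join failed as type mismatch")
        if self == other:
            return self                  # fast path, as in the C++ above
        return AbstractRef(self.ref_key.join(other.ref_key),
                           self.ref.join(other.ref),
                           self.ref_origin.join(other.ref_origin))

r1 = AbstractRef(Abstract("global_step"), Abstract("f32"), Abstract("f32"))
r2 = AbstractRef(Abstract("global_step"), Abstract("f32"), Abstract("f32"))
assert r1.join(r2) is r1
```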
@@ -564,6 +564,7 @@ class AbstractRef : public AbstractBase {
   AbstractBasePtr Broaden() const override {
     return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
   }
+  AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
     return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
   }
@@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
     // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
     if (!(joined_args_spec_list == args_spec_list)) {
       func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+      MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
     }
     return joined_args_spec_list;
   }
@@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
       if (!(joined_args_spec_list == args_spec_list)) {
         trace_.push_back(joined_args_spec_list);
         func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+        MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
       }
       MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
       return joined_args_spec_list;
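Both logging additions sit at the point where BroadenUndeterminedArgs detects a loop-variant argument: joining the current argument specs with those from the previous pass changed something, so constant values must be ignored to avoid wrong constant propagation. A rough sketch of that control flow, with an invented helper name mirroring the C++ function and the same toy Abstract class as above:

```python
class Abstract:                       # same toy class as in the sketch above
    def __init__(self, value):
        self.value = value
    def __eq__(self, other):
        return isinstance(other, Abstract) and self.value == other.value
    def join(self, other):
        return self if self == other else Abstract(None)   # None = unknown

def broaden_undetermined_args(args_spec_list, prev_spec_list, flags):
    """Join pairwise; if anything widened, a loop variant exists."""
    joined = [a.join(b) for a, b in zip(args_spec_list, prev_spec_list)]
    if joined != args_spec_list:
        # This is where the new MS_LOG(DEBUG) line fires in the C++ code.
        flags["IGNORE_VALUES"] = True
    return joined

flags = {}
prev = [Abstract(1), Abstract("f32[1024,512]")]   # x was 1 on the last pass
cur = [Abstract(2), Abstract("f32[1024,512]")]    # x is 2 now: loop-variant
joined = broaden_undetermined_args(cur, prev, flags)
assert flags["IGNORE_VALUES"] and joined[0].value is None
```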
@@ -16,29 +16,54 @@
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, context
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+import mindspore.ops.operations as op
 
 
-class Net(nn.Cell):
-    """ Net definition """
+def test_net_infer():
+    """ test_net_infer """
+    class Net(nn.Cell):
+        """ Net definition """
 
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
-        self.bn = nn.BatchNorm2d(64)
-        self.fc = nn.Dense(64, 10)
-        self.relu = nn.ReLU()
-        self.flatten = nn.Flatten()
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+            self.bn = nn.BatchNorm2d(64)
+            self.fc = nn.Dense(64, 10)
+            self.relu = nn.ReLU()
+            self.flatten = nn.Flatten()
 
-    def construct(self, x):
-        x = self.conv(x)
-        x = self.relu(x)
-        x = self.flatten(x)
-        out = self.fc(x)
-        return out
-
-
-def test_net_infer():
-    """ test_net_infer """
+        def construct(self, x):
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.flatten(x)
+            out = self.fc(x)
+            return out
     Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
     Net()
+
+
+def test_assign_in_while():
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class Net(nn.Cell):
+        def __init__(self, input_shape):
+            super().__init__()
+            self.assign = op.Assign()
+            self.inputdata = Parameter(initializer(1, input_shape), name="global_step")
+
+        def construct(self, x, y, z):
+            out = z
+            while x < y:
+                inputdata = self.inputdata
+                x = x + 1
+                out = self.assign(inputdata, z)
+            return out
+
+    x = Tensor(np.array(1).astype(np.int32))
+    y = Tensor(np.array(3).astype(np.int32))
+    input_shape = (1024, 512)
+    z = Tensor(np.random.randn(*input_shape).astype(np.float32))
+    net = Net(input_shape)
+    ret = net(x, y, z)
+    assert ret == z
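The regression test exercises the chain above: out starts as the plain tensor z, but after an iteration of the while loop it is the output of Assign on a Parameter, so inferring the loop requires joining those two abstracts across iterations; presumably that join is what previously failed for Assign inside while. With MindSpore installed, the test can be run on its own via pytest, e.g. `pytest -k test_assign_in_while`.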