diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index d346c980eab7ee008ed280116530bc2c44cd19ee..e477a147be92b73f6cec290c154aaf59b0f84a69 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
     auto weight_name = weight_node->cast<ParameterPtr>()->name();
     // find the fakequant from input
     int count = 0;
-    int max_depth = 5;
+    const int max_depth = 5;
     while (!is_quant_cnode(x)) {
       if (count >= max_depth) {
         break;
diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
index 86bfecf14bc7f61c3b3b6c8b2ee0b259ae432ea2..b59545e5ae64cf81f7f41ca3d99a9946586edd3c 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
@@ -451,6 +451,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
+  if (*this == *other) {
+    if (sparse_grad() == other->sparse_grad()) {
+      return shared_from_base<AbstractBase>();
+    }
+  }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
@@ -830,6 +835,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
   return false;
 }
 
+AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
+  auto other_ref = other->cast<AbstractRefPtr>();
+  if (other_ref == nullptr) {
+    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
+  }
+  if (*this == *other) {
+    return shared_from_base<AbstractBase>();
+  }
+  auto ref_key = ref_key_->Join(other_ref->ref_key_);
+  auto ref = ref_->Join(other_ref->ref());
+  auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
+
+  return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
+}
+
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
index a5b4acff4520d36f0c380bc1a0dee3cc8e0743f5..3981a6eb2316d474b3a9d25b2c1649e0a2e58c23 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
@@ -578,6 +578,7 @@ class AbstractRef : public AbstractBase {
   AbstractBasePtr Broaden() const override {
     return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
   }
+  AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
     return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
   }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
index c9b1ce4f937256b5e046fa94a286535982f811e7..34ecfc89808f9fa19e8ceabd8327ccd39d597dcb 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
@@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
   // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
   if (!(joined_args_spec_list == args_spec_list)) {
     func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
   }
   return joined_args_spec_list;
 }
@@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
   if (!(joined_args_spec_list == args_spec_list)) {
     trace_.push_back(joined_args_spec_list);
     func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
   }
   MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
   return joined_args_spec_list;
diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py
index 003aad827df6a766284291a36ac80d81a5ff12b5..6b32a7617d204e10c5422d928c04295cded55515 100644
--- a/tests/ut/python/pipeline/infer/test_net_infer.py
+++ b/tests/ut/python/pipeline/infer/test_net_infer.py
@@ -16,29 +16,54 @@
 import numpy as np
 
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, context
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+import mindspore.ops.operations as op
 
 
-class Net(nn.Cell):
-    """ Net definition """
+def test_net_infer():
+    """ test_net_infer """
+    class Net(nn.Cell):
+        """ Net definition """
 
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
-        self.bn = nn.BatchNorm2d(64)
-        self.fc = nn.Dense(64, 10)
-        self.relu = nn.ReLU()
-        self.flatten = nn.Flatten()
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+            self.bn = nn.BatchNorm2d(64)
+            self.fc = nn.Dense(64, 10)
+            self.relu = nn.ReLU()
+            self.flatten = nn.Flatten()
 
-    def construct(self, x):
-        x = self.conv(x)
-        x = self.relu(x)
-        x = self.flatten(x)
-        out = self.fc(x)
-        return out
+        def construct(self, x):
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.flatten(x)
+            out = self.fc(x)
+            return out
+    Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
+    Net()
 
 
-def test_net_infer():
-    """ test_net_infer """
-    Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
-    Net()
+def test_assign_in_while():
+    context.set_context(mode=context.GRAPH_MODE)
+    class Net(nn.Cell):
+        def __init__(self, input_shape):
+            super().__init__()
+            self.assign = op.Assign()
+            self.inputdata = Parameter(initializer(1, input_shape), name="global_step")
+        def construct(self, x, y, z):
+            out = z
+            while x < y:
+                inputdata = self.inputdata
+                x = x + 1
+                out = self.assign(inputdata, z)
+            return out
+
+    x = Tensor(np.array(1).astype(np.int32))
+    y = Tensor(np.array(3).astype(np.int32))
+    input_shape = (1024, 512)
+    z = Tensor(np.random.randn(*input_shape).astype(np.float32))
+    net = Net(input_shape)
+    ret = net(x, y, z)
+    assert ret == z
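
Note for reviewers: below is a minimal standalone sketch of the behavior this patch enables, an Assign executed inside a while loop under GRAPH_MODE, whose loop-variant Ref abstract values are now joined by AbstractRef::Join instead of breaking inference. It mirrors the new test_assign_in_while UT; the AssignInWhile class name, the (4, 4) shape, and the final print are illustrative choices, not part of the patch.

import numpy as np

import mindspore.nn as nn
import mindspore.ops.operations as op
from mindspore import Tensor, context
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE)


class AssignInWhile(nn.Cell):
    """Assigns `z` to a Parameter on every loop iteration."""

    def __init__(self, shape):
        super().__init__()
        self.assign = op.Assign()
        self.weight = Parameter(initializer(1, shape), name="weight")

    def construct(self, x, y, z):
        out = z
        while x < y:
            x = x + 1
            # Each iteration rebinds `out` to the Assign output, so the
            # abstract Ref values must be joined across iterations during
            # static analysis -- the path AbstractRef::Join now handles.
            out = self.assign(self.weight, z)
        return out


x = Tensor(np.array(1).astype(np.int32))
y = Tensor(np.array(3).astype(np.int32))
z = Tensor(np.random.randn(4, 4).astype(np.float32))
print(AssignInWhile((4, 4))(x, y, z))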