diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 1b4f76f7764ace4b88b26476d027d0691981f5f0..2c3168bc62911943e8e95307319503dd67825875 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -147,7 +147,7 @@ def resolve_symbol(namespace, symbol): resolve_ = namespace[symbol] # list and dict is not hashable ,it can not be key for the map, just return the result - if isinstance(resolve_, (list, dict)): + if isinstance(resolve_, (tuple, list, dict)): return resolve_ # dataclass may not be hashable diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index fe1420ad1508f072498a5703e32fe31028b02b3a..0145f4656bc1a85d126f5559d808f2f26178226e 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -642,6 +642,9 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v inputs.push_back(NewValueNode(prim)); size_t size = op_exec_info->op_inputs.size(); + auto const_input_index = prim->get_const_input_indexes(); + bool have_const_input = !const_input_index.empty(); + bool is_const_prim = prim->is_const_prim(); for (size_t i = 0; i < size; i++) { auto obj = op_exec_info->op_inputs[i]; bool op_mask = false; @@ -669,12 +672,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v abs = node->abstract(); } MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " - << prim->is_const_value(); - if (abs == nullptr || prim->is_const_value()) { + << prim->is_const_prim(); + bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i); + if (abs == nullptr || is_const_prim || is_const_input) { MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; ValuePtr input_value = PyAttrValue(obj); abs = input_value->ToAbstract(); - if (!prim->is_const_value()) { + if (!is_const_prim && !is_const_input) { auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = abs->Broaden(config); MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config; @@ -885,7 +889,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { value_ret[0] = output["value"]; return value_ret; } - if (op_exec_info->py_primitive->is_const_value()) { + if (op_exec_info->py_primitive->is_const_prim()) { py::tuple value_ret(1); value_ret[0] = ""; return value_ret; @@ -1044,7 +1048,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { auto tuple = obj.cast(); // cell((1,2)): support not mix (scalar, tensor) - if (tuple.size() > 0 && !py::isinstance(tuple[0])) { + if (!tuple.empty() && !py::isinstance(tuple[0])) { return MakeValueNode(obj, obj_id); } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 1d27f1be594f3bca93e6b48c390cc164f31d017c..a2d06376c4c032347bc5ec90739bbc1816de79f2 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) << ", and the value is " << py::cast(grads[i]) << "."; } - py::tuple grad_shape = grads[i].attr("shape"); + py::object arg_dtype = py_args[i].attr("dtype"); py::object grad_dtype = grads[i].attr("dtype"); py::tuple arg_shape = py_args[i].attr("shape"); - py::object arg_dtype = py_args[i].attr("dtype"); + py::tuple grad_shape = grads[i].attr("shape"); + if (!grad_dtype.equal(arg_dtype)) { + MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i + << "th arg should have the same dtype as the " << i << "th arg, but the " << i + << "th arg dtype is: " << py::cast(arg_dtype) + << ", the gradient dtype is: " << py::cast(grad_dtype) << "."; + } if (!grad_shape.equal(arg_shape)) { MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i << "th arg should have the same shape as the " << i << "th arg, but the " << i << "th arg shape is: " << py::cast(arg_shape) << ", the gradient shape is: " << py::cast(grad_shape) << "."; } - if (!grad_dtype.is(arg_dtype)) { - MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i - << "th arg should have the same dtype as the " << i << "th arg, but the " << i - << "th arg dtype is: " << py::cast(arg_dtype) - << ", the gradient dtype is: " << py::cast(grad_dtype) << "."; - } } } } @@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { bool PrimitivePy::HasComputeFunction() const { auto func = GetComputeFunction(); - if (py::isinstance(func)) { - return false; - } - return true; + return !py::isinstance(func); } PrimitivePtr PrimitivePy::Clone() { @@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") - .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.") + .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") + .def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes, + "Set primitive const input indexes.") .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); diff --git a/mindspore/core/ir/primitive.cc b/mindspore/core/ir/primitive.cc index 1e690d7b47b81965729d2a5670ededd5325f30d7..993e0bab83b77534a7da479515c96f3601476593 100644 --- a/mindspore/core/ir/primitive.cc +++ b/mindspore/core/ir/primitive.cc @@ -32,7 +32,7 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType has_signature_(false), prim_type_(prim_type), record_evaluate_add_attr_(false), - is_const_value_(false), + is_const_prim_(false), id_(MakeId()) {} Primitive::Primitive(const Primitive &prim) @@ -43,7 +43,7 @@ Primitive::Primitive(const Primitive &prim) has_signature_(prim.has_signature_), prim_type_(prim.prim_type_), record_evaluate_add_attr_(false), - is_const_value_(false), + is_const_prim_(false), id_(prim.id_) {} abstract::AbstractBasePtr Primitive::ToAbstract() { diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index be1155b40a1b15d77a902f8fb6cacdf98c4b254e..c6525dabd675b5eedaef561ec3edfd7ff014a744 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -109,8 +109,12 @@ class Primitive : public Named { bool is_base() const { return is_base_; } virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } - void set_is_const_value(bool value) { is_const_value_ = value; } - bool is_const_value() const { return is_const_value_; } + void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; } + bool is_const_prim() const { return is_const_prim_; } + void set_const_input_indexes(const std::vector &const_input_indexes) { + const_input_indexes_ = const_input_indexes; + } + std::vector &get_const_input_indexes() { return const_input_indexes_; } std::string id() const { return id_; } protected: @@ -123,7 +127,8 @@ class Primitive : public Named { bool has_signature_; PrimType prim_type_; bool record_evaluate_add_attr_; - bool is_const_value_; + bool is_const_prim_; + std::vector const_input_indexes_; std::string id_{""}; }; diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 13a26b1e0574051b503d226cccad2decc7430800..bbabc03733cb135c42b30564dbf28b9dcd8d946c 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -28,7 +28,7 @@ hastype = Primitive('hastype') cast = P.Cast() dtype = P.DType() isconstant = Primitive('is_constant') -isconstant.set_is_const_value(True) +isconstant.set_const_prim(True) issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index eb066c44b759c16aa38dfe4b51f48037cc4d699d..b289e3eb68213c13c45afcf0be8f04fd14d2f12e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1089,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init InvertPermutation""" - self.set_is_const_value(True) + self.set_const_prim(True) def __infer__(self, x): x_shp = x['shape'] diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index eb32ea538e71f580b128c4ad95f6849f3264955b..1ccd900b6da29fbe957777349b13ad82fb9abc24 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2873,6 +2873,7 @@ class MirrorPad(PrimitiveWithInfer): """Init Pad""" validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode + self.set_const_input_indexes([1]) def __infer__(self, input_x, paddings): validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 65d9a7deffcb15c586db9421b4dd37095b61164d..7b2596a885f4c4dc4610298c2246624f991971bd 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -390,7 +390,7 @@ def constexpr(fn=None, get_instance=True, name=None): def __init__(self): op_name = name if name else fn.__name__ PrimitiveWithInfer.__init__(self, op_name) - self.set_is_const_value(True) + self.set_const_prim(True) def infer_value(self, *args): return fn(*args)