Commit 9d8fb786 authored by mindspore-ci-bot, committed by Gitee

!5373 Fix bug where the const input is broadened in PyNative mode

Merge pull request !5373 from zhangbuxue/fix_bug_the_const_input_is_broadened_in_PyNative_mode
@@ -147,7 +147,7 @@ def resolve_symbol(namespace, symbol):
     resolve_ = namespace[symbol]
     # list and dict is not hashable ,it can not be key for the map, just return the result
-    if isinstance(resolve_, (list, dict)):
+    if isinstance(resolve_, (tuple, list, dict)):
         return resolve_
     # dataclass may not be hashable
...
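Note on the change above: like a list or a dict, a tuple is only hashable when every element it contains is hashable, so a tuple wrapping a list (or another unhashable object) cannot be used as a cache key either, and it is now returned directly as well. A minimal plain-Python illustration (not MindSpore code):

    hash((1, 2))    # fine: every element is hashable
    hash(([1], 2))  # raises TypeError: unhashable type: 'list'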
@@ -642,6 +642,9 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   inputs.push_back(NewValueNode(prim));
   size_t size = op_exec_info->op_inputs.size();
+  auto const_input_index = prim->get_const_input_indexes();
+  bool have_const_input = !const_input_index.empty();
+  bool is_const_prim = prim->is_const_prim();
   for (size_t i = 0; i < size; i++) {
     auto obj = op_exec_info->op_inputs[i];
     bool op_mask = false;
@@ -669,12 +672,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
       abs = node->abstract();
     }
     MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
-                  << prim->is_const_value();
-    if (abs == nullptr || prim->is_const_value()) {
+                  << prim->is_const_prim();
+    bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
+    if (abs == nullptr || is_const_prim || is_const_input) {
       MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
       ValuePtr input_value = PyAttrValue(obj);
       abs = input_value->ToAbstract();
-      if (!prim->is_const_value()) {
+      if (!is_const_prim && !is_const_input) {
         auto config = abstract::AbstractBase::kBroadenTensorOnly;
         abs = abs->Broaden(config);
         MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config;
@@ -885,7 +889,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
     value_ret[0] = output["value"];
     return value_ret;
   }
-  if (op_exec_info->py_primitive->is_const_value()) {
+  if (op_exec_info->py_primitive->is_const_prim()) {
     py::tuple value_ret(1);
     value_ret[0] = "";
     return value_ret;
@@ -1044,7 +1048,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
     auto tuple = obj.cast<py::tuple>();
     // cell((1,2)): support not mix (scalar, tensor)
-    if (tuple.size() > 0 && !py::isinstance<tensor::Tensor>(tuple[0])) {
+    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
       return MakeValueNode(obj, obj_id);
     }
...
@@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
                                  << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
       }
-      py::tuple grad_shape = grads[i].attr("shape");
-      py::object grad_dtype = grads[i].attr("dtype");
-      py::tuple arg_shape = py_args[i].attr("shape");
-      py::object arg_dtype = py_args[i].attr("dtype");
+      py::object arg_dtype = py_args[i].attr("dtype");
+      py::object grad_dtype = grads[i].attr("dtype");
+      py::tuple arg_shape = py_args[i].attr("shape");
+      py::tuple grad_shape = grads[i].attr("shape");
+      if (!grad_dtype.equal(arg_dtype)) {
+        MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
+                                << "th arg should have the same dtype as the " << i << "th arg, but the " << i
+                                << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
+                                << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
+      }
       if (!grad_shape.equal(arg_shape)) {
         MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
                                  << "th arg should have the same shape as the " << i << "th arg, but the " << i
                                  << "th arg shape is: " << py::cast<py::str>(arg_shape)
                                  << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
       }
-      if (!grad_dtype.is(arg_dtype)) {
-        MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
-                                << "th arg should have the same dtype as the " << i << "th arg, but the " << i
-                                << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
-                                << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
-      }
     }
   }
 }
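Besides moving the dtype check ahead of the shape check, the hunk above switches the dtype comparison from pybind11's identity test (.is(), the counterpart of Python's "is") to value equality (.equal(), i.e. ==): two dtype objects can describe the same type without being the same Python object, so the identity check could wrongly reject a valid bprop. A rough plain-Python analogy (illustrative only, not MindSpore code):

    a = [1, 2]
    b = [1, 2]
    a == b  # True  -> value equality, what grad_dtype.equal(arg_dtype) now checks
    a is b  # False -> object identity, what the old grad_dtype.is(arg_dtype) required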
@@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {

 bool PrimitivePy::HasComputeFunction() const {
   auto func = GetComputeFunction();
-  if (py::isinstance<py::none>(func)) {
-    return false;
-  }
-  return true;
+  return !py::isinstance<py::none>(func);
 }

 PrimitivePtr PrimitivePy::Clone() {
@@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                            .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
                            .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
                            .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
-                           .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
+                           .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")
+                           .def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes,
+                                "Set primitive const input indexes.")
                            .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
                            .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
                            .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
...
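From the Python side, the renamed and newly added bindings are called directly on a primitive instance, as the Python hunks further below do. A short sketch (the primitive name is hypothetical; only set_const_prim and set_const_input_indexes come from this change):

    from mindspore.ops import Primitive

    p = Primitive('my_prim')        # hypothetical primitive
    p.set_const_prim(True)          # replaces the old set_is_const_value(True)
    p.set_const_input_indexes([0])  # newly exposed: mark input 0 as const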
@@ -32,7 +32,7 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType
       has_signature_(false),
       prim_type_(prim_type),
       record_evaluate_add_attr_(false),
-      is_const_value_(false),
+      is_const_prim_(false),
       id_(MakeId()) {}

 Primitive::Primitive(const Primitive &prim)
@@ -43,7 +43,7 @@ Primitive::Primitive(const Primitive &prim)
       has_signature_(prim.has_signature_),
       prim_type_(prim.prim_type_),
       record_evaluate_add_attr_(false),
-      is_const_value_(false),
+      is_const_prim_(false),
       id_(prim.id_) {}

 abstract::AbstractBasePtr Primitive::ToAbstract() {
...
@@ -109,8 +109,12 @@ class Primitive : public Named {
   bool is_base() const { return is_base_; }
   virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; }
   virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
-  void set_is_const_value(bool value) { is_const_value_ = value; }
-  bool is_const_value() const { return is_const_value_; }
+  void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; }
+  bool is_const_prim() const { return is_const_prim_; }
+  void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
+    const_input_indexes_ = const_input_indexes;
+  }
+  std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; }
   std::string id() const { return id_; }

 protected:
@@ -123,7 +127,8 @@ class Primitive : public Named {
   bool has_signature_;
   PrimType prim_type_;
   bool record_evaluate_add_attr_;
-  bool is_const_value_;
+  bool is_const_prim_;
+  std::vector<size_t> const_input_indexes_;
   std::string id_{""};
 };
...
@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
 cast = P.Cast()
 dtype = P.DType()
 isconstant = Primitive('is_constant')
-isconstant.set_is_const_value(True)
+isconstant.set_const_prim(True)
 issubclass_ = P.IsSubClass()
 isinstance_ = P.IsInstance()
...
@@ -1089,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """init InvertPermutation"""
-        self.set_is_const_value(True)
+        self.set_const_prim(True)

     def __infer__(self, x):
         x_shp = x['shape']
...
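InvertPermutation requires its whole input to be a constant so that __infer__ can see the actual permutation; with this fix the value is no longer broadened away in PyNative mode. A minimal usage sketch (assuming a build that contains this change):

    from mindspore import context
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE)
    invert = P.InvertPermutation()
    out = invert((3, 4, 0, 2, 1))  # the tuple reaches __infer__ as a concrete value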
@@ -2873,6 +2873,7 @@ class MirrorPad(PrimitiveWithInfer):
         """Init Pad"""
         validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
         self.mode = mode
+        self.set_const_input_indexes([1])

     def __infer__(self, input_x, paddings):
         validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name)
...
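MirrorPad now declares its second input (paddings, index 1) as const, so in PyNative mode the concrete paddings value stays visible to shape inference instead of being broadened. A hypothetical custom operator following the same pattern (class name and inference body are made up for illustration):

    from mindspore.ops import PrimitiveWithInfer, prim_attr_register

    class MyPad(PrimitiveWithInfer):
        """Hypothetical op whose paddings input must stay constant."""

        @prim_attr_register
        def __init__(self):
            self.set_const_input_indexes([1])  # input 1 (paddings) is const, as in MirrorPad

        def __infer__(self, input_x, paddings):
            # paddings['value'] holds the concrete paddings rather than a broadened
            # placeholder, so the output shape can be computed from it here.
            ...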
@@ -390,7 +390,7 @@ def constexpr(fn=None, get_instance=True, name=None):
        def __init__(self):
            op_name = name if name else fn.__name__
            PrimitiveWithInfer.__init__(self, op_name)
-            self.set_is_const_value(True)
+            self.set_const_prim(True)

        def infer_value(self, *args):
            return fn(*args)
...
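The constexpr decorator wraps a Python function in a one-off PrimitiveWithInfer whose infer_value runs the function at compile time, and it is exactly this generated primitive that now calls set_const_prim(True). An illustrative usage sketch (the decorated function is made up):

    from mindspore.ops import constexpr

    @constexpr
    def doubled_shape(shape):
        # Evaluated on constant arguments during graph construction;
        # the result is folded into the graph as a constant.
        return tuple(2 * s for s in shape)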