提交 9d8fb786 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5373 fix bug the const input is broadened in PyNative mode

Merge pull request !5373 from zhangbuxue/fix_bug_the_const_input_is_broadened_in_PyNative_mode
......@@ -147,7 +147,7 @@ def resolve_symbol(namespace, symbol):
resolve_ = namespace[symbol]
# list and dict is not hashable ,it can not be key for the map, just return the result
if isinstance(resolve_, (list, dict)):
if isinstance(resolve_, (tuple, list, dict)):
return resolve_
# dataclass may not be hashable
......
......@@ -642,6 +642,9 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
inputs.push_back(NewValueNode(prim));
size_t size = op_exec_info->op_inputs.size();
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
bool op_mask = false;
......@@ -669,12 +672,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
abs = node->abstract();
}
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_value();
if (abs == nullptr || prim->is_const_value()) {
<< prim->is_const_prim();
bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
if (abs == nullptr || is_const_prim || is_const_input) {
MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
ValuePtr input_value = PyAttrValue(obj);
abs = input_value->ToAbstract();
if (!prim->is_const_value()) {
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config;
......@@ -885,7 +889,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
value_ret[0] = output["value"];
return value_ret;
}
if (op_exec_info->py_primitive->is_const_value()) {
if (op_exec_info->py_primitive->is_const_prim()) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
......@@ -1044,7 +1048,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
auto tuple = obj.cast<py::tuple>();
// cell((1,2)): support not mix (scalar, tensor)
if (tuple.size() > 0 && !py::isinstance<tensor::Tensor>(tuple[0])) {
if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
return MakeValueNode(obj, obj_id);
}
......
......@@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
}
py::tuple grad_shape = grads[i].attr("shape");
py::object arg_dtype = py_args[i].attr("dtype");
py::object grad_dtype = grads[i].attr("dtype");
py::tuple arg_shape = py_args[i].attr("shape");
py::object arg_dtype = py_args[i].attr("dtype");
py::tuple grad_shape = grads[i].attr("shape");
if (!grad_dtype.equal(arg_dtype)) {
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
}
if (!grad_shape.equal(arg_shape)) {
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
}
if (!grad_dtype.is(arg_dtype)) {
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
}
}
}
}
......@@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
bool PrimitivePy::HasComputeFunction() const {
auto func = GetComputeFunction();
if (py::isinstance<py::none>(func)) {
return false;
}
return true;
return !py::isinstance<py::none>(func);
}
PrimitivePtr PrimitivePy::Clone() {
......@@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")
.def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes,
"Set primitive const input indexes.")
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
......
......@@ -32,7 +32,7 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false),
is_const_value_(false),
is_const_prim_(false),
id_(MakeId()) {}
Primitive::Primitive(const Primitive &prim)
......@@ -43,7 +43,7 @@ Primitive::Primitive(const Primitive &prim)
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false),
is_const_value_(false),
is_const_prim_(false),
id_(prim.id_) {}
abstract::AbstractBasePtr Primitive::ToAbstract() {
......
......@@ -109,8 +109,12 @@ class Primitive : public Named {
bool is_base() const { return is_base_; }
virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; }
virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
void set_is_const_value(bool value) { is_const_value_ = value; }
bool is_const_value() const { return is_const_value_; }
void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; }
bool is_const_prim() const { return is_const_prim_; }
void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
const_input_indexes_ = const_input_indexes;
}
std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; }
std::string id() const { return id_; }
protected:
......@@ -123,7 +127,8 @@ class Primitive : public Named {
bool has_signature_;
PrimType prim_type_;
bool record_evaluate_add_attr_;
bool is_const_value_;
bool is_const_prim_;
std::vector<size_t> const_input_indexes_;
std::string id_{""};
};
......
......@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.set_is_const_value(True)
isconstant.set_const_prim(True)
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
......
......@@ -1089,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init InvertPermutation"""
self.set_is_const_value(True)
self.set_const_prim(True)
def __infer__(self, x):
x_shp = x['shape']
......
......@@ -2873,6 +2873,7 @@ class MirrorPad(PrimitiveWithInfer):
"""Init Pad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
self.mode = mode
self.set_const_input_indexes([1])
def __infer__(self, input_x, paddings):
validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name)
......
......@@ -390,7 +390,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.set_is_const_value(True)
self.set_const_prim(True)
def infer_value(self, *args):
return fn(*args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册