From 02d6e3a43a4e9b1ec8cffd8daa6570e0822b9a94 Mon Sep 17 00:00:00 2001 From: buxue Date: Sat, 8 Aug 2020 16:29:34 +0800 Subject: [PATCH] fix bugs --- mindspore/_extends/parse/__init__.py | 4 +-- mindspore/_extends/parse/parser.py | 6 ++++ .../frontend/operator/composite/composite.cc | 6 ++-- .../ccsrc/pipeline/jit/parse/parse_base.h | 1 + .../pipeline/jit/static_analysis/prim.cc | 6 ++-- .../pipeline/pynative/pynative_execute.cc | 35 +++++++++++++------ .../pipeline/pynative/pynative_execute.h | 3 +- mindspore/ccsrc/utils/primitive_py.cc | 11 ++++-- mindspore/common/tensor.py | 8 +++-- mindspore/ops/_grad/grad_array_ops.py | 2 +- 10 files changed, 58 insertions(+), 24 deletions(-) diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py index 10a991c1e..cd13d329a 100644 --- a/mindspore/_extends/parse/__init__.py +++ b/mindspore/_extends/parse/__init__.py @@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_module_namespace, get_obj_type, get_object_key, get_parse_method_of_class, get_scope_name, - is_class_member, parse_cb, resolve_symbol) + is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor) from .serialize import * __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', @@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', - 'create_slice_obj'] + 'create_slice_obj', 'convert_to_ms_tensor'] diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 695f449f8..30f48f826 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -25,6 +25,7 @@ from dataclasses import is_dataclass import asttokens import mindspore.nn as nn from mindspore import log as logger +from mindspore import Tensor as MsTensor from mindspore import ops from mindspore.common.dtype import pytype_to_dtype from mindspore.common.api import _MindSporeFunction @@ -316,6 +317,11 @@ def get_dataclass_methods(cls): return methods +def convert_to_ms_tensor(data): + """Convert C++ tensor to mindspore tensor.""" + return MsTensor(data) + + class Parser: """ Parser python code to ast tree. diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 262cb789c..b7fa2cf67 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl *step_value = CheckSliceMember(slice->step(), step_default, step_name); if (*step_value == 0) { - MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0."; + MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0."; } if (*step_value < 0) { @@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name); if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) || !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) { - MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index - << " out of range, tuple size " << tuple_size << "."; + MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index + << " out of range, tuple size " << tuple_size << "."; } *start_index = GetPositiveIndex(*start_index, tuple_size); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index d3c6851ed..d2c8d7a2f 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; +const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; const char PYTHON_PARSE_GET_ARGS[] = "get_args"; const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values"; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 1cd9ecdb3..e35d5e761 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); if (specialize_args_before_unpack[index]->isa()) { - AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast(); + auto arg_tuple = specialize_args_before_unpack[index]->cast(); std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); } else if (specialize_args_before_unpack[index]->isa()) { - AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast(); + auto arg_dict = specialize_args_before_unpack[index]->cast(); auto dict_elems = arg_dict->elements(); (void)std::transform( dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), @@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C } auto out_node = out_conf->node()->cast(); const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) { MS_LOG(EXCEPTION) << "MixedPrecisionCast" << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() << ", inputs size " << out_node_inputs.size(); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index b83306a92..579af6e94 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) { static std::string GetId(const py::object &obj) { py::object to_process = obj; std::string prefix = ""; - if (py::isinstance(to_process)) { + if (py::isinstance(to_process) || py::isinstance(to_process)) { auto p_list = py::cast(to_process); - if (p_list.size() == 0) { + if (p_list.empty()) { return "empty"; } - prefix = "tuple:"; + prefix = py::isinstance(to_process) ? "tuple:" : "list"; std::string key = ""; for (size_t i = 0; i < p_list.size(); ++i) { key += std::string(py::str(GetId(p_list[i]))) + ":"; @@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { return node; } +std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) { + auto cell_id = GetId(cell); + for (size_t i = 0; i < args.size(); i++) { + std::string arg_id = GetId(args[i]); + if (node_abs_map_.find(arg_id) != node_abs_map_.end()) { + cell_id += node_abs_map_[arg_id]->ToString(); + } else { + AbstractBasePtr abs = abstract::FromValueInside(PyAttrValue(args[i]), true); + cell_id += abs->ToString(); + node_abs_map_[arg_id] = abs; + } + } + return cell_id; +} + py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; mindspore::parse::python_adapter::set_python_env_flag(true); @@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { } auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); bool is_find = false; - if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) { - auto abs_list = prim_abs_list[prim->id()]; + if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { + auto abs_list = prim_abs_list_[prim->id()]; MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); if (abs_list.find(args_spec_list) != abs_list.end()) { MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name; @@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { if (!is_find) { // const_value need infer every step - auto &out = prim_abs_list[prim->id()]; + auto &out = prim_abs_list_[prim->id()]; out[args_spec_list].abs = op_exec_info->abstract; out[args_spec_list].attrs = prim->evaluate_added_attrs(); MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); @@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); } PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { - auto cell_id = GetId(cell); + auto cell_id = GetCellId(cell, args); if (cell_graph_map_.count(cell_id) != 0) { if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { resource_ = cell_resource_map_[cell_id]; @@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() { } void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { - auto cell_id = GetId(cell); + auto cell_id = GetCellId(cell, args); if (cell_graph_map_.count(cell_id) != 0) { MS_LOG(DEBUG) << "Endgraph already compiled"; return; @@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje inputs.push_back(input); } auto out_cnode = curr_g_->NewCNode(inputs); - set_pyobj(curr_g_, GetId(cell)); + set_pyobj(curr_g_, GetCellId(cell, args)); if (py::isinstance(out)) { auto out_list = py::cast(out); auto out_size = static_cast(out_list.size()); @@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje MS_LOG(INFO) << "GradNet start" << args.size(); std::size_t size = args.size(); - auto cell_id = GetId(cell); + std::string cell_id = GetCellId(cell, args); if (graph_map_.count(cell_id) != 0) { MS_LOG(DEBUG) << "GradNet already compiled"; return; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 246ceada1..11d651c77 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void set_grad_flag(bool flag) { grad_flag_ = flag; } AnfNodePtr GetInput(const py::object &obj, bool op_mask); AnfNodePtr GetObjNode(const py::object &obj); + std::string GetCellId(const py::object &obj, const py::args &args); FuncGraphPtr curr_g() { return curr_g_; } void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { @@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this { FuncGraphPtr top_g_; FuncGraphPtr df_builder_; FuncGraphPtr curr_g_; - std::unordered_map prim_abs_list; + std::unordered_map prim_abs_list_; }; using PynativeExecutorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/utils/primitive_py.cc b/mindspore/ccsrc/utils/primitive_py.cc index 721f52c67..8b1031e18 100644 --- a/mindspore/ccsrc/utils/primitive_py.cc +++ b/mindspore/ccsrc/utils/primitive_py.cc @@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() { } BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { - auto py_args = ConvertDatatoPyTuple(args); + py::tuple py_args = ConvertDatatoPyTuple(args); py::object obj; bool is_bprop = this->HasAttr(kBpropAttrName); if (is_bprop) { SyncData(py_args); - obj = hook_(*py_args); + py::tuple convert_args(py_args.size()); + for (size_t i = 0; i < py_args.size(); i++) { + convert_args[i] = py::isinstance(py_args[i]) + ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, + parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i]) + : py_args[i]; + } + obj = hook_(*convert_args); return std::make_shared(obj); } SyncData(py_args[2]); diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 5730c9f8c..0a975802f 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -210,12 +210,12 @@ class Tensor(Tensor_): @property def shape(self): - """The shape of tensor.""" + """The shape of tensor is a tuple.""" return self._shape @property def dtype(self): - """The dtype of tensor.""" + """The dtype of tensor is a mindspore type.""" return self._dtype @property @@ -248,6 +248,8 @@ class Tensor(Tensor_): Tensor, has the same data type as x. """ + if axis is None: + axis = () return tensor_operator_registry.get('all')(keep_dims)(self, axis) def any(self, axis=(), keep_dims=False): @@ -264,6 +266,8 @@ class Tensor(Tensor_): Tensor, has the same data type as x. """ + if axis is None: + axis = () return tensor_operator_registry.get('any')(keep_dims)(self, axis) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index d4079f38e..043d3aeb9 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self): select = P.Select() def bprop(x, segment_ids, num_segments, out, dout): - gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids) + gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None) is_selected = equal(x, gathered_outputs) is_selected = logical_and(is_selected, is_positive) num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), -- GitLab