提交 02d6e3a4 编写于 作者: B buxue

fix bugs

上级 dc961e46
...@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, ...@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol) is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor)
from .serialize import * from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
...@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', ...@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj'] 'create_slice_obj', 'convert_to_ms_tensor']
...@@ -25,6 +25,7 @@ from dataclasses import is_dataclass ...@@ -25,6 +25,7 @@ from dataclasses import is_dataclass
import asttokens import asttokens
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import log as logger from mindspore import log as logger
from mindspore import Tensor as MsTensor
from mindspore import ops from mindspore import ops
from mindspore.common.dtype import pytype_to_dtype from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction from mindspore.common.api import _MindSporeFunction
...@@ -316,6 +317,11 @@ def get_dataclass_methods(cls): ...@@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
return methods return methods
def convert_to_ms_tensor(data):
"""Convert C++ tensor to mindspore tensor."""
return MsTensor(data)
class Parser: class Parser:
""" """
Parser python code to ast tree. Parser python code to ast tree.
......
...@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl ...@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*step_value = CheckSliceMember(slice->step(), step_default, step_name); *step_value = CheckSliceMember(slice->step(), step_default, step_name);
if (*step_value == 0) { if (*step_value == 0) {
MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0."; MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
} }
if (*step_value < 0) { if (*step_value < 0) {
...@@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl ...@@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name); *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) || if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
!CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) { !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
<< " out of range, tuple size " << tuple_size << "."; << " out of range, tuple size " << tuple_size << ".";
} }
*start_index = GetPositiveIndex(*start_index, tuple_size); *start_index = GetPositiveIndex(*start_index, tuple_size);
......
...@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; ...@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_PARSE_GET_ARGS[] = "get_args"; const char PYTHON_PARSE_GET_ARGS[] = "get_args";
const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values"; const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";
......
...@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s ...@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) { if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>(); auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
} else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) { } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>(); auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
auto dict_elems = arg_dict->elements(); auto dict_elems = arg_dict->elements();
(void)std::transform( (void)std::transform(
dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
...@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C ...@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
} }
auto out_node = out_conf->node()->cast<CNodePtr>(); auto out_node = out_conf->node()->cast<CNodePtr>();
const auto &out_node_inputs = out_node->inputs(); const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "MixedPrecisionCast" MS_LOG(EXCEPTION) << "MixedPrecisionCast"
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size(); << ", inputs size " << out_node_inputs.size();
......
...@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) { ...@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
static std::string GetId(const py::object &obj) { static std::string GetId(const py::object &obj) {
py::object to_process = obj; py::object to_process = obj;
std::string prefix = ""; std::string prefix = "";
if (py::isinstance<py::tuple>(to_process)) { if (py::isinstance<py::tuple>(to_process) || py::isinstance<py::list>(to_process)) {
auto p_list = py::cast<py::tuple>(to_process); auto p_list = py::cast<py::tuple>(to_process);
if (p_list.size() == 0) { if (p_list.empty()) {
return "empty"; return "empty";
} }
prefix = "tuple:"; prefix = py::isinstance<py::tuple>(to_process) ? "tuple:" : "list";
std::string key = ""; std::string key = "";
for (size_t i = 0; i < p_list.size(); ++i) { for (size_t i = 0; i < p_list.size(); ++i) {
key += std::string(py::str(GetId(p_list[i]))) + ":"; key += std::string(py::str(GetId(p_list[i]))) + ":";
...@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { ...@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return node; return node;
} }
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
for (size_t i = 0; i < args.size(); i++) {
std::string arg_id = GetId(args[i]);
if (node_abs_map_.find(arg_id) != node_abs_map_.end()) {
cell_id += node_abs_map_[arg_id]->ToString();
} else {
AbstractBasePtr abs = abstract::FromValueInside(PyAttrValue(args[i]), true);
cell_id += abs->ToString();
node_abs_map_[arg_id] = abs;
}
}
return cell_id;
}
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true); mindspore::parse::python_adapter::set_python_env_flag(true);
...@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { ...@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
} }
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
bool is_find = false; bool is_find = false;
if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) { if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
auto abs_list = prim_abs_list[prim->id()]; auto abs_list = prim_abs_list_[prim->id()];
MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
if (abs_list.find(args_spec_list) != abs_list.end()) { if (abs_list.find(args_spec_list) != abs_list.end()) {
MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name; MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
...@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { ...@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
if (!is_find) { if (!is_find) {
// const_value need infer every step // const_value need infer every step
auto &out = prim_abs_list[prim->id()]; auto &out = prim_abs_list_[prim->id()];
out[args_spec_list].abs = op_exec_info->abstract; out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].attrs = prim->evaluate_added_attrs(); out[args_spec_list].attrs = prim->evaluate_added_attrs();
MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
...@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); } ...@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell); auto cell_id = GetCellId(cell, args);
if (cell_graph_map_.count(cell_id) != 0) { if (cell_graph_map_.count(cell_id) != 0) {
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
resource_ = cell_resource_map_[cell_id]; resource_ = cell_resource_map_[cell_id];
...@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() { ...@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
} }
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetId(cell); auto cell_id = GetCellId(cell, args);
if (cell_graph_map_.count(cell_id) != 0) { if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Endgraph already compiled"; MS_LOG(DEBUG) << "Endgraph already compiled";
return; return;
...@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje ...@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
inputs.push_back(input); inputs.push_back(input);
} }
auto out_cnode = curr_g_->NewCNode(inputs); auto out_cnode = curr_g_->NewCNode(inputs);
set_pyobj(curr_g_, GetId(cell)); set_pyobj(curr_g_, GetCellId(cell, args));
if (py::isinstance<py::tuple>(out)) { if (py::isinstance<py::tuple>(out)) {
auto out_list = py::cast<py::tuple>(out); auto out_list = py::cast<py::tuple>(out);
auto out_size = static_cast<int>(out_list.size()); auto out_size = static_cast<int>(out_list.size());
...@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje ...@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
MS_LOG(INFO) << "GradNet start" << args.size(); MS_LOG(INFO) << "GradNet start" << args.size();
std::size_t size = args.size(); std::size_t size = args.size();
auto cell_id = GetId(cell); std::string cell_id = GetCellId(cell, args);
if (graph_map_.count(cell_id) != 0) { if (graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "GradNet already compiled"; MS_LOG(DEBUG) << "GradNet already compiled";
return; return;
......
...@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { ...@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_grad_flag(bool flag) { grad_flag_ = flag; } void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, bool op_mask); AnfNodePtr GetInput(const py::object &obj, bool op_mask);
AnfNodePtr GetObjNode(const py::object &obj); AnfNodePtr GetObjNode(const py::object &obj);
std::string GetCellId(const py::object &obj, const py::args &args);
FuncGraphPtr curr_g() { return curr_g_; } FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
...@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { ...@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr top_g_; FuncGraphPtr top_g_;
FuncGraphPtr df_builder_; FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_; FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list; std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
}; };
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>; using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
......
...@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() { ...@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() {
} }
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
auto py_args = ConvertDatatoPyTuple(args); py::tuple py_args = ConvertDatatoPyTuple(args);
py::object obj; py::object obj;
bool is_bprop = this->HasAttr(kBpropAttrName); bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) { if (is_bprop) {
SyncData(py_args); SyncData(py_args);
obj = hook_(*py_args); py::tuple convert_args(py_args.size());
for (size_t i = 0; i < py_args.size(); i++) {
convert_args[i] = py::isinstance<tensor::Tensor>(py_args[i])
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
: py_args[i];
}
obj = hook_(*convert_args);
return std::make_shared<PyObjectRef>(obj); return std::make_shared<PyObjectRef>(obj);
} }
SyncData(py_args[2]); SyncData(py_args[2]);
......
...@@ -210,12 +210,12 @@ class Tensor(Tensor_): ...@@ -210,12 +210,12 @@ class Tensor(Tensor_):
@property @property
def shape(self): def shape(self):
"""The shape of tensor.""" """The shape of tensor is a tuple."""
return self._shape return self._shape
@property @property
def dtype(self): def dtype(self):
"""The dtype of tensor.""" """The dtype of tensor is a mindspore type."""
return self._dtype return self._dtype
@property @property
...@@ -248,6 +248,8 @@ class Tensor(Tensor_): ...@@ -248,6 +248,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x. Tensor, has the same data type as x.
""" """
if axis is None:
axis = ()
return tensor_operator_registry.get('all')(keep_dims)(self, axis) return tensor_operator_registry.get('all')(keep_dims)(self, axis)
def any(self, axis=(), keep_dims=False): def any(self, axis=(), keep_dims=False):
...@@ -264,6 +266,8 @@ class Tensor(Tensor_): ...@@ -264,6 +266,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x. Tensor, has the same data type as x.
""" """
if axis is None:
axis = ()
return tensor_operator_registry.get('any')(keep_dims)(self, axis) return tensor_operator_registry.get('any')(keep_dims)(self, axis)
......
...@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self): ...@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self):
select = P.Select() select = P.Select()
def bprop(x, segment_ids, num_segments, out, dout): def bprop(x, segment_ids, num_segments, out, dout):
gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids) gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None)
is_selected = equal(x, gathered_outputs) is_selected = equal(x, gathered_outputs)
is_selected = logical_and(is_selected, is_positive) is_selected = logical_and(is_selected, is_positive)
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册