Commit 02d6e3a4 authored by buxue

fix bugs

Parent dc961e46
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
                      get_dataclass_attributes, get_dataclass_methods, get_obj_id,
                      get_module_namespace, get_obj_type, get_object_key,
                      get_parse_method_of_class, get_scope_name,
-                     is_class_member, parse_cb, resolve_symbol)
+                     is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor)
 from .serialize import *
 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
            'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
            'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
            'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
-           'create_slice_obj']
+           'create_slice_obj', 'convert_to_ms_tensor']
@@ -25,6 +25,7 @@ from dataclasses import is_dataclass
 import asttokens
 import mindspore.nn as nn
 from mindspore import log as logger
+from mindspore import Tensor as MsTensor
 from mindspore import ops
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _MindSporeFunction
@@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
     return methods


+def convert_to_ms_tensor(data):
+    """Convert C++ tensor to mindspore tensor."""
+    return MsTensor(data)
+
+
 class Parser:
     """
     Parser python code to ast tree.
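The new helper is the Python half of the tensor-conversion bridge: the C++ side looks it up by name (see the `PYTHON_MOD_CONVERT_TO_MS_TENSOR` constant below) and calls it to rewrap a runtime tensor. A minimal sketch of the behavior, with a NumPy array standing in for the C++ tensor the runtime actually passes:

```python
import numpy as np
from mindspore import Tensor as MsTensor

def convert_to_ms_tensor(data):
    """Convert C++ tensor to mindspore tensor."""
    return MsTensor(data)

# A NumPy array stands in for the runtime tensor here.
t = convert_to_ms_tensor(np.ones((2, 3), dtype=np.float32))
print(t.shape, t.dtype)  # (2, 3) Float32
```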
@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
   *step_value = CheckSliceMember(slice->step(), step_default, step_name);
   if (*step_value == 0) {
-    MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0.";
+    MS_EXCEPTION(ValueError) << "TupleSlice requires the step value not be 0, but got 0.";
   }
   if (*step_value < 0) {
@@ -941,7 +941,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
   *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
   if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
       !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
-    MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
+    MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end index " << *stop_index
                              << " out of range, tuple size " << tuple_size << ".";
   }
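Switching from MS_LOG(EXCEPTION) to MS_EXCEPTION(ValueError) changes the exception type surfaced in Python: an illegal tuple slice now raises ValueError instead of a generic RuntimeError. A hedged sketch of the user-visible effect (the exact message text comes from GenerateTupleSliceParameter above):

```python
import mindspore.nn as nn

class Net(nn.Cell):
    def construct(self, t):
        return t[0:3:0]  # a slice step of 0 is rejected for tuples

# Compiling and running Net()((1, 2, 3)) is now expected to raise
# ValueError rather than a generic runtime error.
```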
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
 const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
 const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
 const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
+const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
 const char PYTHON_PARSE_GET_ARGS[] = "get_args";
 const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";
@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
   for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
     MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
     if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
-      AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
+      auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
       std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
                      std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
     } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
-      AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
+      auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
       auto dict_elems = arg_dict->elements();
       (void)std::transform(
           dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
   }
   auto out_node = out_conf->node()->cast<CNodePtr>();
   const auto &out_node_inputs = out_node->inputs();
-  if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
+  if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
     MS_LOG(EXCEPTION) << "MixedPrecisionCast"
                       << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
                       << ", inputs size " << out_node_inputs.size();
@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
 static std::string GetId(const py::object &obj) {
   py::object to_process = obj;
   std::string prefix = "";
-  if (py::isinstance<py::tuple>(to_process)) {
+  if (py::isinstance<py::tuple>(to_process) || py::isinstance<py::list>(to_process)) {
     auto p_list = py::cast<py::tuple>(to_process);
-    if (p_list.size() == 0) {
+    if (p_list.empty()) {
       return "empty";
     }
-    prefix = "tuple:";
+    prefix = py::isinstance<py::tuple>(to_process) ? "tuple:" : "list:";
     std::string key = "";
     for (size_t i = 0; i < p_list.size(); ++i) {
       key += std::string(py::str(GetId(p_list[i]))) + ":";
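GetId now builds a composite key for lists the same way it already did for tuples, with distinct prefixes so a tuple and a list holding the same elements get different ids. A rough Python analogue of the keying scheme (a sketch, not the real implementation):

```python
def get_id(obj):
    """Sketch of GetId's composite-key scheme for sequences."""
    if isinstance(obj, (tuple, list)):
        if not obj:
            return "empty"
        prefix = "tuple:" if isinstance(obj, tuple) else "list:"
        return prefix + "".join(get_id(e) + ":" for e in obj)
    return str(id(obj))
```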
@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   return node;
 }

+std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
+  auto cell_id = GetId(cell);
+  for (size_t i = 0; i < args.size(); i++) {
+    std::string arg_id = GetId(args[i]);
+    if (node_abs_map_.find(arg_id) != node_abs_map_.end()) {
+      cell_id += node_abs_map_[arg_id]->ToString();
+    } else {
+      AbstractBasePtr abs = abstract::FromValueInside(PyAttrValue(args[i]), true);
+      cell_id += abs->ToString();
+      node_abs_map_[arg_id] = abs;
+    }
+  }
+  return cell_id;
+}
+
 py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
   mindspore::parse::python_adapter::set_python_env_flag(true);
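GetCellId appends the abstract (roughly, a dtype-and-shape summary) of every argument to the cell's object id, so the same Cell instance called with differently shaped or typed inputs no longer hits a stale cached graph; node_abs_map_ caches the abstract per argument id. In Python terms the cache key works roughly like this (a sketch; the real strings come from GetId and AbstractBase::ToString):

```python
def get_cell_id(cell, args):
    """Sketch: cell identity plus a summary of each argument."""
    key = str(id(cell))
    for a in args:
        key += f"/{a.dtype}{a.shape}"  # stands in for abstract->ToString()
    return key
```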
@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
   }
   auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
   bool is_find = false;
-  if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) {
-    auto abs_list = prim_abs_list[prim->id()];
+  if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
+    auto abs_list = prim_abs_list_[prim->id()];
     MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
     if (abs_list.find(args_spec_list) != abs_list.end()) {
       MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
   if (!is_find) {
     // const_value need infer every step
-    auto &out = prim_abs_list[prim->id()];
+    auto &out = prim_abs_list_[prim->id()];
     out[args_spec_list].abs = op_exec_info->abstract;
     out[args_spec_list].attrs = prim->evaluate_added_attrs();
     MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
 PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }

 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
-  auto cell_id = GetId(cell);
+  auto cell_id = GetCellId(cell, args);
   if (cell_graph_map_.count(cell_id) != 0) {
     if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
       resource_ = cell_resource_map_[cell_id];
@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
 }

 void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
-  auto cell_id = GetId(cell);
+  auto cell_id = GetCellId(cell, args);
   if (cell_graph_map_.count(cell_id) != 0) {
     MS_LOG(DEBUG) << "Endgraph already compiled";
     return;
@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     inputs.push_back(input);
   }
   auto out_cnode = curr_g_->NewCNode(inputs);
-  set_pyobj(curr_g_, GetId(cell));
+  set_pyobj(curr_g_, GetCellId(cell, args));
   if (py::isinstance<py::tuple>(out)) {
     auto out_list = py::cast<py::tuple>(out);
     auto out_size = static_cast<int>(out_list.size());
@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
   MS_LOG(INFO) << "GradNet start" << args.size();
   std::size_t size = args.size();
-  auto cell_id = GetId(cell);
+  std::string cell_id = GetCellId(cell, args);
   if (graph_map_.count(cell_id) != 0) {
     MS_LOG(DEBUG) << "GradNet already compiled";
     return;
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   AnfNodePtr GetObjNode(const py::object &obj);
+  std::string GetCellId(const py::object &obj, const py::args &args);
   FuncGraphPtr curr_g() { return curr_g_; }
   void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr top_g_;
   FuncGraphPtr df_builder_;
   FuncGraphPtr curr_g_;
-  std::unordered_map<std::string, AbstractListMap> prim_abs_list;
+  std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
 };

 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() {
 }

 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
-  auto py_args = ConvertDatatoPyTuple(args);
+  py::tuple py_args = ConvertDatatoPyTuple(args);
   py::object obj;
   bool is_bprop = this->HasAttr(kBpropAttrName);
   if (is_bprop) {
     SyncData(py_args);
-    obj = hook_(*py_args);
+    py::tuple convert_args(py_args.size());
+    for (size_t i = 0; i < py_args.size(); i++) {
+      convert_args[i] = py::isinstance<tensor::Tensor>(py_args[i])
+                          ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
+                                                            parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
+                          : py_args[i];
+    }
+    obj = hook_(*convert_args);
     return std::make_shared<PyObjectRef>(obj);
   }
   SyncData(py_args[2]);
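The practical effect is that hook and bprop functions registered from Python now receive mindspore.Tensor objects instead of raw C++ tensors, so ordinary tensor methods work inside them. A hedged usage sketch (assuming the HookBackward primitive):

```python
from mindspore.ops import operations as P

def my_hook(grad):
    # grad now arrives as a mindspore.Tensor rather than a C++ tensor,
    # so methods such as .shape and .asnumpy() are available here.
    print(grad.shape)

hook = P.HookBackward(my_hook)
```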
@@ -210,12 +210,12 @@ class Tensor(Tensor_):
     @property
     def shape(self):
-        """The shape of tensor."""
+        """The shape of the tensor, as a tuple."""
         return self._shape

     @property
     def dtype(self):
-        """The dtype of tensor."""
+        """The dtype of the tensor, as a MindSpore type."""
         return self._dtype

     @property
@@ -248,6 +248,8 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        if axis is None:
+            axis = ()
         return tensor_operator_registry.get('all')(keep_dims)(self, axis)

     def any(self, axis=(), keep_dims=False):
@@ -264,6 +266,8 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        if axis is None:
+            axis = ()
         return tensor_operator_registry.get('any')(keep_dims)(self, axis)
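With the added guard, Tensor.all and Tensor.any treat axis=None the same as the default axis=(), reducing over every dimension instead of failing inside the underlying operator. A small sketch:

```python
import numpy as np
from mindspore import Tensor

t = Tensor(np.array([[True, True], [True, False]]))
print(t.all(axis=None))  # same as t.all(): reduce over all dimensions
print(t.any(axis=None))  # likewise for any
```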
@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self):
     select = P.Select()

     def bprop(x, segment_ids, num_segments, out, dout):
-        gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
+        gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None)
         is_selected = equal(x, gathered_outputs)
         is_selected = logical_and(is_selected, is_positive)
         num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
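The fix passes the helper's two optional arguments explicitly as None rather than relying on their defaults inside this bprop. A hedged sketch of a net whose gradient exercises the corrected bprop (assuming the current GradOperation usage):

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.UnsortedSegmentMin()

    def construct(self, x, segment_ids):
        return self.op(x, segment_ids, 2)

x = Tensor(np.array([[1., 2.], [3., 4.], [5., 6.]], np.float32))
ids = Tensor(np.array([0, 1, 1], np.int32))
# Taking this gradient runs the bprop defined above.
grads = C.GradOperation(get_all=True)(Net())(x, ids)
```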