Commit 358982a9 authored by kingfo

fix hook and bprop debug issue

Parent fe797aaf
@@ -113,6 +113,24 @@ def bool_or(x, y):
    """Implement `bool_or`."""
    return x or y
def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
def make_list(*xs):
    """Implement `make_list`."""
...
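A quick sketch of how the new `vm_compare` fallback is expected to be called (illustrative values; the registry wrapper added further down appends the method name as the trailing argument):

```python
import numpy as np
from mindspore import Tensor
from mindspore._extends.builtin_operations import vm_compare

x = Tensor(np.array([1, 2, 3], np.float32))
y = Tensor(np.array([1, 0, 3], np.float32))

# Binary compare: dispatches to x.asnumpy().__eq__(y.asnumpy()) and wraps the result.
print(vm_compare(x, y, "__eq__"))   # Tensor holding [ True False  True]

# "shape" is special-cased and returns the numpy shape attribute directly.
print(vm_compare(x, "shape"))       # (3,)
```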
@@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
auto bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = bprop_graph->add_parameter();
auto p2 = bprop_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
data_converter::SetObjGraphValue(obj_key, bprop_graph);
return bprop_graph;
}
namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting python tuple";
@@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  return true;
}
FuncGraphPtr ConvertToBpropCut(py::object obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
py::function bprop_func = py::getattr(obj, "bprop");
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = bprop_graph->add_parameter();
auto p2 = bprop_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
data_converter::SetObjGraphValue(obj_key, bprop_graph);
return bprop_graph;
}
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
  FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  if (func_graph == nullptr) {
@@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
    return false;
  }
  // if the cell object has specified bprop, it has user-defined bprop function parse and record it
-  if (py::hasattr(obj, "bprop")) {
+  if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
    FuncGraphPtr bprop_graph = nullptr;
    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
    if (enable_bprop_debug) {
@@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
    }
    if (bprop_graph != nullptr) {
-      (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
+      (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
      func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
    }
...
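For context, a minimal Python-side cell of the shape `ConvertToBpropCut` targets (hypothetical example, mirroring the test added below): the generated `bprop_cut` graph gets `co_argcount - 3` input parameters (excluding `self`, `out` and `dout`) plus two trailing parameters for `out` and `dout`.

```python
import mindspore.nn as nn
from mindspore import Tensor

class ReluWithBprop(nn.Cell):
    def __init__(self):
        super(ReluWithBprop, self).__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        # co_argcount == 4 here, so the bprop_cut graph gets 4 - 3 == 1 input
        # parameter, plus the two trailing parameters for out and dout.
        return (Tensor(out.asnumpy() * dout.asnumpy()),)
```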
@@ -51,6 +51,7 @@ void ClearObjectCache();
}  // namespace data_converter
ClassPtr ParseDataClass(const py::object &cls_obj);
FuncGraphPtr ConvertToBpropCut(const py::object &obj);
void CleanDataClassToClassMap();
...
@@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast";  // for ast type namespace
...
@@ -45,7 +45,7 @@ enum PynativeStatusCode {
  PYNATIVE_UNKNOWN_STATE = 0XFF
};
-enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM };
+enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
  PrimitivePyPtr py_primitive;
...
@@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
  return obj_tuple;
}
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
  auto &py_args = *out_args;
  py::tuple input_mask(args.size());
  for (size_t i = 0; i < args.size(); ++i) {
    if (py::hasattr(args[i], "__parameter__")) {
      input_mask[i] = true;
    } else {
      input_mask[i] = false;
    }
    py_args[i] = GetTupleObj(args[i]);
  }
  auto signature = prim->signatures();
@@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
                 [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return;
+    return input_mask;
  }
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
  for (size_t i = 0; i < dtypes.size(); ++i) {
@@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
      continue;
    }
  }
  return input_mask;
}
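The mask computed above relies on Parameters exposing a `__parameter__` attribute, the same check the removed Python-side `_run_op` code used. A hedged sketch of what gets flagged:

```python
import numpy as np
from mindspore import Parameter, Tensor

w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w")
x = Tensor(np.ones((2, 2), np.float32))

# ConvertInputs marks w with mask True and x with mask False.
print(hasattr(w, "__parameter__"), hasattr(x, "__parameter__"))  # True False
```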
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
@@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < size; i++) {
    ValuePtr input_value = PyAttrValue(py_args[i]);
-    if (input_value->isa<tensor::Tensor>()) {
+    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
    } else {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
-    MS_LOG(ERROR) << "Four args are needed by RunOp";
+    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
@@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  size_t input_num = a.size();
  op_exec_info->op_inputs = py::tuple(input_num);
-  ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
+  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  op_exec_info->inputs_mask = args[PY_INPUT_MASK];
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
    return nullptr;
@@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
  return result;
}
-AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
+AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
-  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
+  auto prim = op_exec_info->py_primitive;
  inputs.push_back(NewValueNode(prim));
-  py::tuple op_masks = args[PY_INPUT_MASK];
+  py::tuple op_masks = op_exec_info->inputs_mask;
  py::list op_args = args[PY_INPUTS];
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < op_args.size(); i++) {
@@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
    return err_ret;
  }
-  auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
+  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
@@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
-  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
+  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
@@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
    }
  }
-  auto output_node = GetObjNode(out);
+  AnfNodePtr output_node;
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
output_node = graph_info_map_[curr_g_].param_map[out_id];
} else {
output_node = GetObjNode(out);
}
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
MS_LOG(DEBUG) << "Use cell custom bprop function.";
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
if (bprop_graph != nullptr) {
(void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
}
}
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    Popp();
...
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args);
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
void ClearPyNativeSession();
@@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
  }
-  AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
  py::object Run(const py::tuple &args, const py::object &phase);
  void Pushp();
...
@@ -16,6 +16,7 @@
"""Registry the relation."""
from collections import UserDict
from .. import context
class Registry(UserDict):
@@ -27,9 +28,16 @@ class Registry(UserDict):
    def get(self, obj_str):
        """Get the value by str."""
-        if isinstance(obj_str, str):
+        if not isinstance(obj_str, str):
            raise TypeError("key for tensor registry must be string.")
        if context.get_context("enable_ge"):
            def wrap(*args):
                new_args = list(args)
                new_args.append(obj_str)
                return self["vm_compare"](*new_args)
            obj = wrap
        else:
            obj = self[obj_str]
        return obj

tensor_operator_registry = Registry()
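A hedged usage sketch of the new `get` behaviour (illustrative; in GE mode the returned wrapper forwards to the `vm_compare` entry registered in functional.py further down):

```python
from mindspore import context
from mindspore.common._register_for_tensor import tensor_operator_registry

eq = tensor_operator_registry.get('__eq__')
# With context.set_context(enable_ge=True), eq(x, y) is roughly
# tensor_operator_registry['vm_compare'](x, y, '__eq__');
# otherwise it is the registered `equal` primitive.
```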
@@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor
from .._checkparam import check_type, check_typename
from . import dtype as mstype
-from .. import context
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor']
@@ -76,17 +75,19 @@ class Tensor(Tensor_):
        return out
    def __eq__(self, other):
-        if not isinstance(other, Tensor):
+        if not isinstance(other, (int, float, Tensor)):
            return False
-        # The GE backend don't support single `Equal` operator execution.
        # bool type is not supported for `Equal` operator in backend.
-        if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
+        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
            return Tensor(np.array(self.asnumpy() == other.asnumpy()))
        return tensor_operator_registry.get('__eq__')(self, other)
    def __ne__(self, other):
-        if not isinstance(other, Tensor):
+        if not isinstance(other, (int, float, Tensor)):
            return True
# bool type is not supported for `NotEqual` operator in backend.
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
return Tensor(np.array(self.asnumpy() != other.asnumpy()))
        return tensor_operator_registry.get('__ne__')(self, other)
    def __hash__(self):
@@ -105,7 +106,7 @@ class Tensor(Tensor_):
        return out
    def __radd__(self, other):
-        out = tensor_operator_registry.get('__add__')(other, self)
+        out = tensor_operator_registry.get('__add__')(self, other)
        return out
    def __imul__(self, other):
@@ -113,15 +114,15 @@ class Tensor(Tensor_):
        return out
    def __rmul__(self, other):
-        out = tensor_operator_registry.get('__mul__')(other, self)
+        out = tensor_operator_registry.get('__mul__')(self, other)
        return out
    def __truediv__(self, other):
-        out = tensor_operator_registry.get('__div__')(self, other)
+        out = tensor_operator_registry.get('__truediv__')(self, other)
        return out
    def __rtruediv__(self, other):
-        out = tensor_operator_registry.get('__div__')(other, self)
+        out = tensor_operator_registry.get('__truediv__')(other, self)
        return out
    def __sub__(self, other):
@@ -160,7 +161,7 @@ class Tensor(Tensor_):
        return out
    def __len__(self):
-        out = tensor_operator_registry.get('__shape__')(self)
+        out = tensor_operator_registry.get('shape')(self)
        if not out:
            return 1
        return out[0]
...
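A short sketch of what the `__eq__`/`__ne__` changes allow (illustrative values): scalar operands are now dispatched to the registered operators, while bool tensors fall back to numpy because `Equal`/`NotEqual` do not support bool inputs.

```python
import numpy as np
from mindspore import Tensor

x = Tensor(np.array([1, 2, 3], np.float32))
print(x == 2)                                     # scalar rhs now goes through Equal
print(x != Tensor(np.array([1, 0, 3], np.float32)))

b = Tensor(np.array([True, False]))
print(b == Tensor(np.array([True, True])))        # bool dtype compared via numpy
```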
@@ -819,4 +819,4 @@ class Cell:
        """
        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
-        self._enable_hook = True
+        self.enable_hook = True
@@ -140,6 +140,11 @@ class SequentialCell(Cell):
    def __len__(self):
        return len(self._cells)
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
    def construct(self, input_data):
        for cell in self.cell_list:
            input_data = cell(input_data)
@@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
        self._cells[str(len(self))] = cell
        return self
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
    def construct(self, *inputs):
        raise NotImplementedError
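A usage sketch of the recursive `set_grad` added to the container cells (assuming the standard `Cell.requires_grad` flag and `cells()` child iterator):

```python
import mindspore.nn as nn

net = nn.SequentialCell([nn.Dense(3, 4), nn.ReLU()])
net.set_grad()  # now also flips requires_grad on the contained Dense and ReLU
assert net.requires_grad and all(c.requires_grad for c in net.cells())
```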
@@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
        grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
        if self.grad_fn is None or self.fn != fn:
            if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
+                if context.get_context("mode") == context.GRAPH_MODE:
                    @ms_function(obj=fn)
                    def after_grad(*args):
                        return grad_(fn, weights)(*args)
...
@@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .._extends import builtin_operations as BP
typeof = Primitive('typeof')
hastype = Primitive('hastype')
@@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
-tensor_operator_registry.register('__div__', tensor_div)
+tensor_operator_registry.register('__truediv__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
@@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
-tensor_operator_registry.register('__shape__', shape)
+tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
@@ -933,6 +933,8 @@ class TupleToArray(PrimitiveWithInfer):
        args = list()
        if isinstance(x, range):
            args.append(tuple(x))
else:
args.append(x)
        return _run_op(self, self.name, args)
...
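Sketch of the behaviour the small fix above restores in the GE/PyNative `__call__` path shown (illustrative): a plain tuple argument is now forwarded to `_run_op` instead of being silently dropped.

```python
import mindspore.ops.operations as P

to_array = P.TupleToArray()
print(to_array((1, 2, 3)))   # Tensor([1, 2, 3])
print(to_array(range(4)))    # a range is expanded to a tuple by the branch above
```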
@@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
    """Single op execution function supported by ge in PyNative mode."""
-    op_mask = [0] * len(args)
-    op_inputs = []
-    for i, arg in enumerate(args):
-        if hasattr(arg, '__parameter__'):
-            op_mask[i] = 1
-        op_inputs.append(arg)
-    output = real_run_op(obj, op_name, args, tuple(op_mask))
+    output = real_run_op(obj, op_name, args)
    if not output:
        raise RuntimeError("Pynative run op %s failed!" % op_name)
    if len(output) == 1:
...
@@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
  auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
  py::none py_none;
-  py::tuple op_mask = py::make_tuple(0, 1);
-  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask));
+  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs));
}
TEST_F(TestPynativeExecute, TestRunOpInVM) {
@@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
  py::none py_none;
  auto op_exec_info_ptr = ConstructOpExecInfo();
  py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name,
-                                                     op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask));
+                                                     op_exec_info_ptr->op_inputs));
  if (outputs.size() == 0) {
    FAIL();
  } else {
...
@@ -452,5 +452,5 @@ def test_tensor_operation():
    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
    res = 8 / x
    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
        res = x * (2, 3)
@@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
cell_hook_done = False
var_hook_done = False
cell_bprop_done = False
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -32,15 +35,35 @@ def weight_variable():
def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
global cell_hook_done
cell_hook_done = True
    assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
    assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
def var_hook_function(grad_out):
    print("grad:", grad_out)
global var_hook_done
var_hook_done = True
    assert (grad_out[0].asnumpy().shape == (32, 120))
class Block(nn.Cell):
def __init__(self):
super(Block, self).__init__()
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(x)
return x
def bprop(self, x, out, dout):
global cell_bprop_done
cell_bprop_done = True
grad = out.asnumpy() * dout.asnumpy()
grad = Tensor(grad)
return (grad,)
class LeNet5(nn.Cell):
    """
    Lenet network
@@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
self.block = Block()
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
@@ -72,7 +96,7 @@
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
-        x = self.relu(x)
+        x = self.block(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
@@ -110,6 +134,9 @@ def test_hook():
        loss_output = criterion(output, label)
        grads = train_network(input_data, label)
        success = optimizer(grads)
assert cell_hook_done
assert var_hook_done
assert cell_bprop_done
    print(loss_output.asnumpy().shape)
...