提交 38436f92 编写于 作者: K kingfo

move hook function to primtivePy class

上级 444d9484
......@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""builtin_operations"""
import functools
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
......@@ -124,17 +123,8 @@ def list_len(x):
"""Implement `list_len`."""
return len(x)
# only used in PyNative mode
def partial(*args):
"""Implement `partial`."""
func = args[0].__call__
partial_func = functools.partial(func, *args[1:])
return partial_func
# only used in PyNative mode
def depend(value, expr):
def Depend(value, expr):
"""Implement `Depend`."""
return value
# only used in PyNative mode
......
......@@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {
void AddPyAttr(const py::str &name, const py::object &obj);
py::dict GetAttrDict();
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
......@@ -56,6 +58,7 @@ class PrimitivePy : public Primitive {
private:
py::object python_obj_;
py::function hook_;
std::vector<Signature> signatures_;
};
......
......@@ -89,9 +89,6 @@ class Primitive : public Named {
return iter == attrs_.cend() ? nullptr : iter->second;
}
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }
......@@ -124,7 +121,6 @@ class Primitive : public Named {
private:
std::string instance_name_;
py::function hook_;
bool is_base_;
bool has_signature_;
PrimType prim_type_;
......
......@@ -220,7 +220,7 @@ const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("partial");
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
......@@ -237,7 +237,7 @@ const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("depend");
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
......
......@@ -238,8 +238,12 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
auto func_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
bprop_cut->set_hook(prim->hook());
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
if (!prim->is_base()) {
PrimitivePyPtr prim_py = dyn_cast<PrimitivePy>(prim);
bprop_cut->set_hook(prim_py->hook());
}
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (cell_id != "") {
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
......
......@@ -72,7 +72,7 @@ constexpr char OP[] = "op";
constexpr char IDENTITY_INFO[] = "identity_info";
constexpr char DIVISOR[] = "divisor";
constexpr char NONE[] = "None";
constexpr char DEPEND[] = "depend";
constexpr char DEPEND[] = "Depend";
constexpr char BATCH_PARALLEL[] = "BatchParallel";
constexpr char ACTIVATION_TYPE[] = "activation_type";
......
......@@ -217,7 +217,7 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
......
......@@ -59,7 +59,7 @@ struct OpExecInfo {
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"partial", "make_ref"};
const std::set<std::string> ignore_infer_prim = {"make_ref"};
} // namespace pynative
} // namespace mindspore
......
......@@ -53,7 +53,7 @@
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "HookBackward"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward"};
namespace mindspore {
namespace pynative {
......
......@@ -959,8 +959,8 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
for (unsigned int i = 1; i < c->inputs().size(); i++) {
TraceOutput(c->input(i));
}
} else if (name == "depend") {
if (c->inputs().size() < 3) { // "depend" primitive have 3 inputs
} else if (name == "Depend") {
if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs
MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
}
TraceOutput(c->input(1));
......@@ -1183,7 +1183,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
auto &inputs = node->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
auto pred = inputs[i];
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "depend") {
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") {
pred = pred->cast<CNodePtr>()->input(1);
}
// skip the None input
......@@ -1362,7 +1362,7 @@ AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned in
AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().size() < 3) { // "depend" primitive have 3 inputs
if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs
MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3";
}
return cnode->inputs()[1];
......@@ -1483,7 +1483,7 @@ AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) {
// depend apply inputs: depend,output,depended_node
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto depend_inputs = node->cast<CNodePtr>()->inputs();
if (depend_inputs.size() != 3) { // "depend" primitive have 3 inputs
if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs
MS_LOG(ERROR) << "depend input items not correct";
error_ = FAILED;
return node;
......@@ -1700,7 +1700,7 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
// ignore apply node of return
if (name == "return" || name == "depend") {
if (name == "return" || name == "Depend") {
return false;
}
......
......@@ -585,8 +585,8 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
return;
}
VectorRef tuple;
auto prim = utils::cast<PrimitivePtr>(args[0]);
VectorRef tuple;
for (size_t i = 1; i < args.size(); ++i) {
auto index = utils::cast<int>(args[i]);
tuple.push_back(Ref(index));
......@@ -618,6 +618,7 @@ void FinalVM::SyncData(const py::object &arg) {
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
MS_LOG(DEBUG) << "input for operation:";
auto prim_py = dyn_cast<PrimitivePy>(prim);
std::size_t args_size = args.size();
auto py_args = py::tuple(args_size);
size_t i = 0;
......@@ -631,7 +632,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
bool is_bprop = prim->HasAttr("bprop");
if (is_bprop) {
SyncData(py_args);
py::function fn_bprop = prim->hook();
py::function fn_bprop = prim_py->hook();
obj = fn_bprop(*py_args);
return obj;
}
......@@ -647,7 +648,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
hook_args[0] = cell_id;
hook_args[1] = py::make_tuple(_hook_grad[cell_id]);
hook_args[2] = py::make_tuple(py_args[2]);
py::function fn_hook = prim->hook();
py::function fn_hook = prim_py->hook();
obj = fn_hook(*hook_args);
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
......@@ -659,7 +660,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
}
} else {
// Hook operator for execute variable hook function
py::function fn_hook = prim->hook();
py::function fn_hook = prim_py->hook();
obj = fn_hook(py::make_tuple(py_args[2]));
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
......
......@@ -78,6 +78,8 @@ class Tensor(Tensor_):
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
# The GE backend don't support single `Equal` operator execution.
# bool type is not supported for `Equal` operator in backend.
if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_:
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)
......
......@@ -195,7 +195,7 @@ def bprop_array_reduce(fn, x, shp, out, dout):
return F.distribute(dout, F.shape(x)), C.zeros_like(shp)
@bprops.register("depend")
@bprops.register("Depend")
def bprop_depend(x, y, out, dout):
"""Backpropagator for primitive `depend`."""
return dout, C.zeros_like(y)
......@@ -236,7 +236,6 @@ def bprop_control_depend(x, y, out, dout):
"""Backpropagator for primitive `Control_depend`."""
return C.zeros_like(x), C.zeros_like(y)
@bprops.register("switch")
def bprop_switch(cond, tb, fb, out, dout):
"""Backpropagator for primitive `switch`."""
......
......@@ -22,7 +22,7 @@ from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
from ...common.parameter import Parameter
......@@ -117,6 +117,7 @@ class GradOperation(GradOperation_):
def after_grad(*args):
return grad_(fn, weights)(*args)
else:
@_wrap_func
def after_grad(*args):
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
......
......@@ -77,6 +77,9 @@ gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
pack = P.Pack()
partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
tuple_setitem = Primitive('tuple_setitem')
......@@ -131,12 +134,9 @@ mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
partial = Primitive('partial')
zeros_like = P.ZerosLike()
identity = Primitive('identity')
distribute = Primitive('distribute')
# depend: mount a node to another node
depend = Primitive('depend')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
env_setitem = Primitive('env_setitem')
......
......@@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
CheckValid, MakeRefKey, CheckBprop, ConfusionMatrix)
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix)
from . import _quant_ops
from ._quant_ops import *
from .thor_ops import *
......@@ -213,6 +213,8 @@ __all__ = [
'NMSWithMask',
'IOU',
'MakeRefKey',
'Partial',
'Depend',
'AvgPool',
# Back Primitive
'Equal',
......
......@@ -14,6 +14,7 @@
# ============================================================================
"""Other operators."""
import functools
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
......@@ -304,6 +305,46 @@ class MakeRefKey(Primitive):
pass
class Partial(Primitive):
"""
Make a partial function instance, used for pynative mode.
Inputs:
- **args** (Union[FunctionType, Tensor]) - The function and bind arguments.
Outputs:
FunctionType, partial function binded with arguments.
"""
@prim_attr_register
def __init__(self):
pass
def __call__(self, *args):
func = args[0].__call__
partial_func = functools.partial(func, *args[1:])
return partial_func
class Depend(Primitive):
"""
Depend is used for process side-effect operations.
Inputs:
- **value** (Tensor) - the real value to return for depend operator.
- **expr** (Expression) - the expression to execute with no outputs.
Outputs:
Tensor, the value passed by last operator.
"""
@prim_attr_register
def __init__(self):
pass
def __call__(self, value, expr):
return value
class CheckBprop(PrimitiveWithInfer):
"""
Checks whether data type and shape of corresponding element from tuple x and y are the same.
......
......@@ -341,7 +341,7 @@ TEST_F(TestOps, ResolveTest) {
}
TEST_F(TestOps, PartialTest) {
auto prim = std::make_shared<Primitive>("partial");
auto prim = std::make_shared<Primitive>("Partial");
ASSERT_EQ(prim->name(), kPrimPartial->name());
}
......
......@@ -636,7 +636,7 @@ def test_tuple_get_set_item(tag):
def test_partial(tag):
""" test_partial """
fns = FnDict()
partail = Primitive('partial')
partail = P.Partial()
def f(x, y):
return scalar_add(x, y)
......@@ -655,7 +655,7 @@ def test_partial(tag):
def test_replace_applicator(tag):
""" test_replace_applicator """
fns = FnDict()
partail = Primitive('partial')
partail = P.Partial()
def app1(x, y):
return scalar_add(x, y)
......
......@@ -22,7 +22,7 @@ four2five = Primitive('Four2Five')
five2four = Primitive('Five2Four')
transdata = Primitive("TransData")
cast = Primitive('Cast')
depend = Primitive('depend')
depend = P.Depend()
class FnDict:
......
......@@ -16,13 +16,13 @@ import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
AssignSub = P.AssignSub()
Mul = P.Mul()
Sub = P.Sub()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
depend = Primitive('depend')
BatchNorm = P.BatchNorm()
Cast = P.Cast()
BNTrainingReduce = Primitive('BNTrainingReduce')
......@@ -54,8 +54,8 @@ def test_fused_batch_norm_fusion(tag):
mul1 = Mul(sub1, constant1)
assign_sub0 = AssignSub(var0, mul0)
assign_sub1 = AssignSub(var1, mul1)
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = depend(depend0, assign_sub1)
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = F.depend(depend0, assign_sub1)
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
output = tuple_getitem(outputs, 0)
return output
......@@ -69,8 +69,8 @@ def test_fused_batch_norm_fusion(tag):
mul1 = Mul(sub1, constant1)
assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = depend(depend0, assign_sub1)
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = F.depend(depend0, assign_sub1)
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
output = tuple_getitem(outputs, 0)
return output
......@@ -84,8 +84,8 @@ def test_fused_batch_norm_fusion(tag):
mul1 = Mul(Cast(sub1, mstype.float32), constant1)
assign_sub0 = AssignSub(var0, mul0)
assign_sub1 = AssignSub(var1, mul1)
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = depend(depend0, assign_sub1)
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
depend1 = F.depend(depend0, assign_sub1)
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
output = tuple_getitem(outputs, 0)
return output
......
......@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P
tuple_getitem = Primitive('tuple_getitem')
depend = Primitive('depend')
depend = P.Depend()
addn = P.AddN()
add = P.TensorAdd()
sub = P.Sub()
......
......@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P
tuple_getitem = Primitive('tuple_getitem')
depend = Primitive('depend')
depend = P.Depend()
addn = P.AddN()
add = P.TensorAdd()
sub = P.Sub()
......
......@@ -15,7 +15,7 @@
from mindspore.ops import Primitive
from mindspore.ops import operations as P
depend = Primitive('depend')
depend = P.Depend()
TransData = Primitive('TransData')
add = P.TensorAdd()
make_tuple = Primitive('make_tuple')
......
......@@ -20,9 +20,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
......@@ -358,7 +358,7 @@ class StateNet(nn.Cell):
self.assign = P.Assign()
def construct(self, x):
x = Primitive('depend')(x, self.assign(self.s1, x + self.s1))
x = F.depend(x, self.assign(self.s1, x + self.s1))
self.s1 = self.sub(self.s1, x)
self.s2 = self.sub(self.s2, x)
return x
......
......@@ -132,7 +132,7 @@ def test_hypermap_add3_easy():
add3 = C.MultitypeFuncGraph('add')
partial = Primitive('partial')
partial = P.Partial()
@add3.register("Number", "Number", "Number")
......
......@@ -284,3 +284,21 @@ def vm_impl_zeros_like(self):
"""Generate vm_impl function for ZerosLike"""
def vm_impl(x):
return Tensor(np.zeros_like(x.asnumpy()))
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
"""Generate vm_impl function for Partial"""
def vm_impl(*args):
func = args[0].__call__
partial_func = functools.partial(func, *args[1:])
return partial_func
return vm_impl
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
"""Generate vm_impl function for Depend"""
def vm_impl(value, expr):
return value
return vm_impl
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册