Commit 879a5191 authored by Wei Luning

update signature

Parent 03093778
......@@ -20,7 +20,6 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def scalar_add(x, y):
"""Implement `scalar_add`."""
return x + y
......@@ -117,25 +116,6 @@ def bool_or(x, y):
return x or y
def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
def make_list(*xs):
"""Implement `make_list`."""
return list(xs)
......
......@@ -262,6 +262,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
std::set<size_t> write_indices;
std::vector<TypePtr> input_types;
op_inputs.push_back(NewValueNode(function));
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
// Assume the write input of the op is always the first input. We check whether any input is a write input,
// and add a cast op on the other inputs to keep them the same type as the assigned parameter.
for (size_t i = 0; i < args_spec_list.size(); ++i) {
......@@ -280,7 +281,6 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
TypePtr type = args_spec_list[i]->BuildType();
if (type && type->isa<RefType>()) {
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
if (sig == SignatureEnumRW::kRWRead) {
auto source_tensor_type = type->cast<TensorTypePtr>();
if (source_tensor_type != nullptr) {
......@@ -300,8 +300,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
<< type->ToString();
}
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
<< args_spec_list[i]->ToString();
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
<< args_spec_list[i]->ToString() << " type " << type->ToString();
input_types.push_back(type);
op_inputs.push_back(param);
}
......
......@@ -305,9 +305,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = arg_slice->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractRef>()) {
auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = py::ellipsis();
......
......@@ -23,7 +23,7 @@ namespace mindspore {
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
// Define python "MetaFuncGraph_" class
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
.def(py::init<std::string &>());
.def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature.");
// Define python "FuncGraph" class
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
.def(py::init())
......
......@@ -48,22 +48,9 @@ void SyncData(const py::object &arg) {
}
} // namespace
std::map<std::string, py::object> PrimitivePy::hook_grad_;
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
return parse::data_converter::PyDataToValue(arg);
}
void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto &signature : signatures) {
auto [name, rw, kind, arg_default, dtype] = signature;
auto default_value = PyArgToValue(arg_default);
signatures_.emplace_back(name, rw, kind, default_value, dtype);
}
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;
set_has_signature(true);
}
......
......@@ -42,9 +42,7 @@ class PrimitivePy : public Primitive {
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
void set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
signatures);
void set_signatures(const std::vector<Signature> &signatures);
const std::vector<Signature> &signatures() const { return signatures_; }
......
......@@ -17,12 +17,26 @@
#include "ir/signature.h"
#include "pybind11/operators.h"
#include "pybind_api/api_register.h"
#include "pipeline/jit/parse/data_converter.h"
namespace py = pybind11;
namespace mindspore {
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
return parse::data_converter::PyDataToValue(arg);
}
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::class_<Signature>(*m, "Signature")
.def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
py::object arg_default, SignatureEnumDType dtype) {
auto default_value = PyArgToValue(arg_default);
return Signature(name, rw, kind, default_value, dtype);
}));
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite)
......
......@@ -393,3 +393,24 @@ class SparseTensor:
@property
def dense_shape(self):
return self.__dense_shape
def _vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
tensor_operator_registry.register('vm_compare', _vm_compare)
......@@ -34,14 +34,17 @@ from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind
from . import composite, operations, functional
from . import signature
__primitive__ = [
"prim_attr_register", "Primitive", "PrimitiveWithInfer",
"signature_rw", "signature_kind"
"prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature"
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)
__all__.extend(composite.__all__)
__all__.extend(operations.__all__)
__all__.extend(functional.__all__)
......@@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
from ...common.parameter import Parameter
from ...common.tensor import Tensor
from .. import signature as sig
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
......@@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
Args:
name (str): Operator name.
read_value (bool): If the registered function does not need to write values back to Parameters,
and all inputs should be passed by value, set `read_value` to True. Default: False.
Raises:
ValueError: Cannot find matching fn for the given args.
......@@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
>>> add = MultitypeFuncGraph('add')
"""
def __init__(self, name):
def __init__(self, name, read_value=False):
MultitypeFuncGraph_.__init__(self, name)
self.entries = list()
if read_value:
self.set_signatures((
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
def __call__(self, *args):
def unwrap(arg):
if isinstance(arg, Parameter):
return arg.data
return arg
types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args))
types = tuple(map(mstype.get_py_obj_dtype, args))
for sigs, fn in self.entries:
if len(sigs) != len(types):
continue
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
add = base.MultitypeFuncGraph('add')
add = base.MultitypeFuncGraph('add', True)
"""`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator."""
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
div = base.MultitypeFuncGraph("div")
div = base.MultitypeFuncGraph("div", True)
"""
div is a metafuncgraph object which will divide two objects according to input type
using the ".register" decorator
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
equal = base.MultitypeFuncGraph("equal")
equal = base.MultitypeFuncGraph("equal", True)
"""
equal is a metafuncgraph object which will determine if two objects are equal according to input type
using ".register" decorator
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
floordiv = base.MultitypeFuncGraph("floordiv")
floordiv = base.MultitypeFuncGraph("floordiv", True)
"""
`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from .. import base
from ... import functional as F
getitem = base.MultitypeFuncGraph('getitem')
getitem = base.MultitypeFuncGraph('getitem', True)
"""
getitem is a metafuncgraph object which will get item from an object according to input type
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
# using ".register" decorator
greater_equal = base.MultitypeFuncGraph("greater_equal")
greater_equal = base.MultitypeFuncGraph("greater_equal", True)
@greater_equal.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
# using ".register" decorator
greater = base.MultitypeFuncGraph("greater")
greater = base.MultitypeFuncGraph("greater", True)
@greater.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
in_ = base.MultitypeFuncGraph("in")
in_ = base.MultitypeFuncGraph("in", True)
"""
in_ is a metafuncgraph object which will determine whether a is in b
using ".register" decorator
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# less_equal is a metagraph object which will determine if two objects are less_equal according to input type
# using ".register" decorator
less_equal = base.MultitypeFuncGraph("less_equal")
less_equal = base.MultitypeFuncGraph("less_equal", True)
@less_equal.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# less is a metafuncgraph object which will determine if two objects are less according to input type
# using ".register" decorator
less = base.MultitypeFuncGraph("less")
less = base.MultitypeFuncGraph("less", True)
@less.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_not is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_not = base.MultitypeFuncGraph("logical_not")
logical_not = base.MultitypeFuncGraph("logical_not", True)
@logical_not.register("Number")
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_and is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_and = base.MultitypeFuncGraph("logical_and")
logical_and = base.MultitypeFuncGraph("logical_and", True)
@logical_and.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_or is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_or = base.MultitypeFuncGraph("logical_or")
logical_or = base.MultitypeFuncGraph("logical_or", True)
@logical_or.register("Number", "Number")
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
mod = base.MultitypeFuncGraph("mod")
mod = base.MultitypeFuncGraph("mod", True)
"""
`mod` is a metafuncgraph object which will compute the mod of two objects
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
mul = base.MultitypeFuncGraph("mul")
mul = base.MultitypeFuncGraph("mul", True)
"""
`mul` is a metafuncgraph object which will multiply two objects according to input type
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
negative = base.MultitypeFuncGraph("negative")
negative = base.MultitypeFuncGraph("negative", True)
"""
`negative` is a metafuncgraph object which will give the negative of an object according to its input type
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
not_equal = base.MultitypeFuncGraph("not_equal")
not_equal = base.MultitypeFuncGraph("not_equal", True)
"""
not_equal is a metafuncgraph object which will determine if two objects are not equal according to input type
using ".register" decorator
......
......@@ -22,7 +22,7 @@ from ... import functional as F
from ... import operations as P
ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf')
ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf', True)
"""
`ones_like_leaf` is a metafuncgraph object which will generate a tensor filled with ones according to its input type
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
pow_ = base.MultitypeFuncGraph("pow")
pow_ = base.MultitypeFuncGraph("pow", True)
"""
`pow` is a metafuncgraph object which will compute the pow of two objects
using ".register" decorator.
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
sub = base.MultitypeFuncGraph("sub")
sub = base.MultitypeFuncGraph("sub", True)
"""
`sub` is a metafuncgraph object which will compute the subtraction of two objects
using ".register" decorator.
......
......@@ -18,7 +18,7 @@ from mindspore.ops.composite import base
# uadd is a metagraph object which will return operation result regarding input
# using ".register" decorator
uadd = base.MultitypeFuncGraph("uadd")
uadd = base.MultitypeFuncGraph("uadd", True)
@uadd.register("Tensor")
@uadd.register("Number")
......
......@@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf')
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
"""
`zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with zeros according to its input type
using ".register" decorator.
......
......@@ -21,7 +21,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .._extends import builtin_operations as BP
typeof = Primitive('typeof')
hastype = Primitive('hastype')
......@@ -182,5 +181,6 @@ tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
# support the GE backend, which has no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)
__all__ = [name for name in dir() if name[0] != "_"]
......@@ -15,8 +15,7 @@
"""Operators for gradients."""
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
......@@ -1500,7 +1499,7 @@ class RefToEmbed(Primitive):
>>> return key, self.weight
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD),
sig.make_sig('variable', sig.sig_rw.RW_REF),
)
@prim_attr_register
......
......@@ -28,10 +28,7 @@ import numpy as np
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import typing
from .. import signature as sig
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
......@@ -44,9 +41,9 @@ class _ScatterOp(PrimitiveWithInfer):
Define Scatter operators
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
sig.make_sig('updates', dtype=sig.sig_dtype.T)
)
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
......@@ -1396,7 +1393,7 @@ class Tile(PrimitiveWithInfer):
validator.check_value_type("shape", multiples_v, [tuple], self.name)
for i, multiple in enumerate(multiples_v):
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name)
len_sub = len(multiples_v) - len(x_shp)
multiples_w = None
if len_sub == 0:
......
......@@ -18,9 +18,7 @@
import copy
import numpy as np
from ... import context
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from .. import signature as sig
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
......@@ -68,7 +66,7 @@ class _BinaryOp(PrimitiveWithInfer):
Define binary operators.
"""
__mindspore_signature__ = (sig_dtype.T, sig_dtype.T)
__mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T)
@prim_attr_register
def __init__(self):
......@@ -186,8 +184,8 @@ class AssignAdd(PrimitiveWithInfer):
>>> net(value)
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T)
)
@prim_attr_register
......@@ -237,8 +235,8 @@ class AssignSub(PrimitiveWithInfer):
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T)
)
@prim_attr_register
......@@ -264,8 +262,8 @@ class _Reduce(PrimitiveWithInfer):
"""
__mindspore_signature__ = (
('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('axis', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, ()),
sig.make_sig('input_x'),
sig.make_sig('axis', default=())
)
@prim_attr_register
......
This diff is collapsed.
......@@ -15,9 +15,7 @@
"""Other operators."""
import functools
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
......@@ -53,8 +51,8 @@ class Assign(Primitive):
>>> net(x)
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T)
)
@prim_attr_register
......
......@@ -14,17 +14,13 @@
# ============================================================================
"""primitive"""
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind
from .._c_expression import signature_dtype as sig_dtype
from . import signature as sig
class Primitive(Primitive_):
"""
......@@ -54,24 +50,21 @@ class Primitive(Primitive_):
self._update_parameter = False
Primitive_.__init__(self, name, self)
if hasattr(self.__class__, '__mindspore_signature__'):
sig = self._fill_signature(self.__class__.__mindspore_signature__)
self.set_signatures(sig)
out = self._fill_signature(self.__class__.__mindspore_signature__)
self.set_signatures(out)
def _fill_signature(self, signatures):
"""fills signature."""
signatures_new = []
for signature in signatures:
if isinstance(signature, sig_dtype):
signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD,
sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature))
if isinstance(signature, sig.Signature):
signatures_new.append(signature)
elif isinstance(signature, sig.sig_dtype):
signatures_new.append(sig.make_sig(dtype=signature))
else:
if len(signature) < 3:
raise ValueError(f"[Internal Error] Signature for one parameter must have at least 3 elements, but got {signature}")
if len(signature) == 3:
signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE)
if len(signature) == 4:
signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,)
signatures_new.append(signature)
signatures_new.append(sig.make_sig(*signature))
return tuple(signatures_new)
def _clone(self):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""signature"""
from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind
from .._c_expression import signature_dtype as sig_dtype
from .._c_expression import Signature
def make_sig(name="var", rw=sig_rw.RW_READ,
kind=sig_kind.KIND_POSITIONAL_KEYWORD,
default=sig_kind.KIND_EMPTY_DEFAULT_VALUE,
dtype=sig_dtype.T_EMPTY_DEFAULT_VALUE):
"""
Make signature for one argument.
See `ApplyMomentum` in `mindspore.ops.operations.nn_ops` as an example.
Args:
name (str): Argument name. Default: "var".
rw (:class:`mindspore.ops.signature.sig_rw`): Tag the argument's read/write attribute. Choose from
`[sig_rw.RW_READ, sig_rw.RW_WRITE, sig_rw.RW_REF]` to tag whether the argument updates the input:
`sig_rw.RW_READ` for a read-only argument, `sig_rw.RW_WRITE` for a write-only argument, and
`sig_rw.RW_REF` for an argument that is both read and written. Default: sig_rw.RW_READ.
kind (:class:`mindspore.ops.signature.sig_kind`): Choose from `[sig_kind.KIND_POSITIONAL_KEYWORD,
sig_kind.KIND_VAR_POSITIONAL, sig_kind.KIND_KEYWORD_ONLY, sig_kind.KIND_VAR_KEYWORD]`.
The meaning is the same as the Python argument kinds; refer to the Python documentation.
Default: sig_kind.KIND_POSITIONAL_KEYWORD.
default (Any): The default value of the argument, or `sig_kind.KIND_EMPTY_DEFAULT_VALUE` for no default value.
Default: sig_kind.KIND_EMPTY_DEFAULT_VALUE.
dtype (:class:`mindspore.ops.signature.sig_dtype`): Choose from `sig_dtype.T`, or `sig_dtype.T1`
to `sig_dtype.T9`, or `sig_dtype.T_EMPTY_DEFAULT_VALUE` for no constraint.
Arguments sharing the same dtype tag are auto type converted to match each other; if any of them
is tagged `sig_rw.RW_WRITE`, the other arguments are converted to the type of the
`sig_rw.RW_WRITE` argument. Default: sig_dtype.T_EMPTY_DEFAULT_VALUE.
Returns:
:class:`mindspore.ops.signature.Signature`, signature for one argument.
"""
return Signature(name, rw, kind, default, dtype)
......@@ -136,13 +136,15 @@ class NetForCast(nn.Cell):
super(NetForCast, self).__init__()
self.concat = P.Concat()
self.x1 = Tensor(1.0, mstype.float32)
self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
def construct(self, x0):
x = self.x1 * x0
x = self.x1 * x0 * self.x2
return x
def test_cast():
context.set_context(save_graphs=True)
x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
net = NetForCast()
net.add_flags_recursive(fp16=True)
......
......@@ -16,9 +16,7 @@
import functools
import numpy as np
import pytest
from mindspore._c_expression import signature_dtype as sig_dtype
from mindspore._c_expression import signature_kind as sig_kind
from mindspore._c_expression import signature_rw as sig_rw
from mindspore.ops.signature import sig_rw, sig_dtype, make_sig
import mindspore as ms
from mindspore import Tensor
......@@ -126,9 +124,9 @@ class CustomOP(PrimitiveWithInfer):
class CustomOP2(PrimitiveWithInfer):
__mindspore_signature__ = (
('p1', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('p2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('p3', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
make_sig('p1', sig_rw.RW_WRITE, dtype=sig_dtype.T),
make_sig('p2', dtype=sig_dtype.T),
make_sig('p3', dtype=sig_dtype.T),
)
@prim_attr_register
......