Commit b12e6ff7 authored by: Z zhaozhenlong

add operator diag and diag_part

上级 f1b72229
......@@ -178,6 +178,8 @@ const char kNameLARSUpdate[] = "LARSUpdate";
const char kNameRound[] = "Round";
const char kNamePrint[] = "Print";
const char kNameApplyFtrl[] = "ApplyFtrl";
const char kNameDiag[] = "Diag";
const char kNameDiagPart[] = "DiagPart";
// -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
......@@ -357,7 +359,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
{string(kNameSign), ADPT_DESC(Sign)},
{string(kNameRound), ADPT_DESC(Round)},
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)}};
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
{string(kNameDiag), ADPT_DESC(Diag)},
{string(kNameDiagPart), ADPT_DESC(DiagPart)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
#endif
......
......@@ -1173,6 +1173,16 @@ INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INP
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
// Diag: adapter maps translating the MindSpore primitive onto the GE operator.
// Input indices are 1-based (input 1 -> descriptor `x`); output indices are 0-based.
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
// Diag carries no attributes.
ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
// DiagPart: same single-input/single-output, attribute-free mapping as Diag.
INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
#ifdef ENABLE_GE
// Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
......
......@@ -435,6 +435,10 @@ DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(ApplyFtrl)
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
// Declare op adapters for the newly added Diag / DiagPart operators
// (their INPUT/ATTR/OUTPUT maps are defined with the other adapter definitions).
DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(DiagPart)
DECLARE_OP_USE_OUTPUT(DiagPart)
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)
......
......@@ -408,3 +408,25 @@ def get_bprop_depth_to_space(self):
return (op(dout),)
return bprop
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Generate bprop for Diag: the gradient is the diagonal part of dout."""
    extract_diag = P.DiagPart()

    def bprop(x, out, dout):
        # Diag scatters values onto a diagonal, so its gradient gathers them back.
        dx = extract_diag(dout)
        return (dx,)

    return bprop
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Generate bprop for DiagPart: the gradient re-diagonalizes dout."""
    build_diag = P.DiagPart.__mro__ and P.Diag()  # Diag is the adjoint of DiagPart

    def bprop(x, out, dout):
        # DiagPart gathers the diagonal, so its gradient scatters dout back
        # onto a diagonal tensor of the input's shape.
        return (build_diag(dout),)

    return bprop
......@@ -20,7 +20,7 @@ A collection of operators to build nerual networks or computing functions.
"""
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
Diag, DType, ExpandDims, Eye,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
......@@ -208,6 +208,7 @@ __all__ = [
"Cos",
"ACos",
"Diag",
"DiagPart",
'Eye',
'Assign',
'AssignAdd',
......
......@@ -1615,37 +1615,96 @@ class StridedSlice(PrimitiveWithInfer):
class Diag(PrimitiveWithInfer):
    r"""
    Construct a diagonal tensor with a given diagonal values.

    Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
    rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
    :math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, with the same dtype as `input_x` and rank twice the rank of `input_x`.

    Examples:
        >>> input_x = Tensor([1, 2, 3, 4])
        >>> diag = P.Diag()
        >>> diag(input_x)
        [[1, 0, 0, 0],
         [0, 2, 0, 0],
         [0, 0, 3, 0],
         [0, 0, 0, 4]]
    """

    @prim_attr_register
    def __init__(self):
        """init Diag"""

    def infer_dtype(self, x_type):
        # Input must be a tensor; the output keeps the input's dtype.
        validator.check_subclass('input_x', x_type, mstype.tensor)
        return x_type

    def infer_shape(self, x_shape):
        # Input must have rank >= 1; output shape is the input shape repeated
        # twice, i.e. rank 2k for a rank-k input.
        validator.check("x rank", len(x_shape), "", 1, Rel.GE)
        ret_shape = copy.deepcopy(x_shape)
        ret_shape = ret_shape + ret_shape
        return ret_shape

    def infer_value(self, x):
        # Constant folding is only possible when the input value is known.
        if x is None:
            return None
        # np.diag only *constructs* a diagonal matrix for 1-D input (for 2-D it
        # extracts instead), so restrict compile-time evaluation to rank 1.
        validator.check("input x rank", len(x.shape()), "", 1)
        ret = np.diag(x.asnumpy())
        return Tensor(ret)
class DiagPart(PrimitiveWithInfer):
    r"""
    Extract the diagonal part from given tensor.

    Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
    of rank k with dimensions :math:`[D_1,..., D_k]` where:
    :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.

    Inputs:
        - **input_x** (Tensor) - The input tensor of non-zero, even rank 2k.

    Outputs:
        Tensor, with the same dtype as `input_x` and half the rank of `input_x`.

    Raises:
        ValueError: If the input rank is zero or odd.

    Examples
        >>> input_x = Tensor([[1, 0, 0, 0],
        >>>                   [0, 2, 0, 0],
        >>>                   [0, 0, 3, 0],
        >>>                   [0, 0, 0, 4]])
        >>> diag_part = P.DiagPart()
        >>> diag_part(input_x)
        [1, 2, 3, 4]
    """

    @prim_attr_register
    def __init__(self):
        """init DiagPart"""

    def infer_dtype(self, x_type):
        # Input must be a tensor; the output keeps the input's dtype.
        validator.check_subclass('input_x', x_type, mstype.tensor)
        return x_type

    def infer_shape(self, x_shape):
        # The input shape must be [D_1,...,D_k, D_1,...,D_k]; reject zero/odd rank.
        if len(x_shape) % 2 != 0 or \
                not x_shape:
            raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
                             f"with shapes {x_shape}")
        length = len(x_shape) // 2
        ret_shape = x_shape[0:length]
        return ret_shape

    def infer_value(self, x):
        # Bug fix: the None guard must come before any use of x.shape(),
        # otherwise inference crashes whenever the input value is unknown.
        if x is None:
            return None
        # np.diag only *extracts* a diagonal for a 2-D input, so restrict
        # compile-time evaluation to rank 2 (higher even ranks run at runtime).
        validator.check("x rank", len(x.shape()), "", 2)
        ret = np.diag(x.asnumpy())
        return Tensor(ret)
......
......@@ -942,6 +942,16 @@ test_case_array_ops = [
Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}),
('Diag', {
'block': P.Diag(),
'desc_inputs': [[4]],
'desc_bprop': [[4, 4]],
}),
('DiagPart', {
'block': P.DiagPart(),
'desc_inputs': [[4, 4]],
'desc_bprop': [[4]],
}),
]
test_case_other_ops = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.