提交 ac78ac97 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2297 add vm support for operators include MatrixDiag, MatrixDiagPart etc

Merge pull request !2297 from jiangjinsheng/vm_matrixdiag
...@@ -124,7 +124,10 @@ static std::map<string, string> tbe_func_adapter_map = { ...@@ -124,7 +124,10 @@ static std::map<string, string> tbe_func_adapter_map = {
{"a_cos_grad", "acos_grad"}, {"a_cos_grad", "acos_grad"},
{"histogram_fixed_width", "histogram_fixed_width_d"}, {"histogram_fixed_width", "histogram_fixed_width_d"},
{"broadcast_to", "broadcast_to_d"}, {"broadcast_to", "broadcast_to_d"},
{"inplace_update", "inplace_update_d"}}; {"inplace_update", "inplace_update_d"},
{"matrix_diag", "matrix_diag_d"},
{"matrix_diag_part", "matrix_diag_part_d"},
{"matrix_set_diag", "matrix_set_diag_d"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) { void TbeAdapter::NormalizeFuncName(std::string *func_name) {
if (func_name == nullptr) { if (func_name == nullptr) {
......
...@@ -31,9 +31,12 @@ from mindspore.ops import _selected_ops ...@@ -31,9 +31,12 @@ from mindspore.ops import _selected_ops
from ..cell import Cell from ..cell import Cell
from .activation import get_activation from .activation import get_activation
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold'] __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']
class Dropout(Cell): class Dropout(Cell):
r""" r"""
...@@ -527,3 +530,112 @@ class Unfold(Cell): ...@@ -527,3 +530,112 @@ class Unfold(Cell):
ret = self.extract_image_patches(x_transpose) ret = self.extract_image_patches(x_transpose)
ret_transpose = self.transpose(ret, self.format_NCHW) ret_transpose = self.transpose(ret, self.format_NCHW)
return ret_transpose return ret_transpose
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
return Tensor(assist, x_dtype)
@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
return Tensor(assist, x_dtype)
class MatrixDiag(Cell):
"""
Returns a batched diagonal tensor with a given batched diagonal values.
Inputs:
- **x** (Tensor) - The diagonal values. It can be of the following data types:
float32, float16, int32, int8, uint8.
Outputs:
Tensor, same type as input `x`. The shape should be x.shape + (x.shape[-1], ).
Examples:
>>> x = Tensor(np.array([1, -1]), mstype.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> result = matrix_diag(x)
[[1. 0.]
[0. -1.]]
"""
def __init__(self):
super(MatrixDiag, self).__init__()
self.matrix_diag = inner.MatrixDiag()
self.dtype = P.DType()
def construct(self, input_x):
x_shape = F.shape(input_x)
x_dtype = self.dtype(input_x)
assist = _get_matrix_diag_assist(x_shape, x_dtype)
out_matrix_diag = self.matrix_diag(input_x, assist)
return out_matrix_diag
class MatrixDiagPart(Cell):
r"""
Returns the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be of the following data types:
float32, float16, int32, int8, uint8.
Outputs:
Tensor, same type as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])].
Examples:
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> matrix_diag_part = nn.MatrixDiagPart()
>>> result = matrix_diag_part(x)
[[-1., 1.], [-1., 1.], [-1., 1.]]
"""
def __init__(self):
super(MatrixDiagPart, self).__init__()
self.matrix_diag_part = inner.MatrixDiagPart()
self.dtype = P.DType()
def construct(self, input_x):
x_shape = F.shape(input_x)
x_dtype = self.dtype(input_x)
assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
out_matrix_diag_part = self.matrix_diag_part(input_x, assist)
return out_matrix_diag_part
class MatrixSetDiag(Cell):
r"""
Modify the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be of the following data types:
float32, float16, int32, int8, uint8.
- **diagonal** (Tensor) - The diagonal values.
Outputs:
Tensor, same type as input `x`. The shape same as `x`.
Examples:
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
>>> matrix_set_diag = nn.MatrixSetDiag()
>>> result = matrix_set_diag(x, diagonal)
[[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
"""
def __init__(self):
super(MatrixSetDiag, self).__init__()
self.matrix_set_diag = inner.MatrixSetDiag()
self.dtype = P.DType()
def construct(self, input_x, diagonal):
x_shape = F.shape(input_x)
x_dtype = self.dtype(input_x)
assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist)
return out_matrix_set_diag
...@@ -264,3 +264,6 @@ from .inplace_update import _inplace_update_tbe ...@@ -264,3 +264,6 @@ from .inplace_update import _inplace_update_tbe
from .splitv import _split_v_tbe from .splitv import _split_v_tbe
from .in_top_k import _in_top_k_tbe from .in_top_k import _in_top_k_tbe
from .lin_space import _lin_space_tbe from .lin_space import _lin_space_tbe
from .matrix_diag import _matrix_diag_tbe
from .matrix_diag_part import _matrix_diag_part_tbe
from .matrix_set_diag import _matrix_set_diag_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixDiagD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
matrix_diag_d_op_info = TBERegOp("MatrixDiag") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("matrix_diag_d.so") \
.compute_cost(10) \
.kernel_name("matrix_diag_d") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "assist", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(matrix_diag_d_op_info)
def _matrix_diag_tbe():
"""MatrixDiagD TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixDiagPartD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
matrix_diag_part_d_op_info = TBERegOp("MatrixDiagPart") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("matrix_diag_part_d.so") \
.compute_cost(10) \
.kernel_name("matrix_diag_part_d") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "assist", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(matrix_diag_part_d_op_info)
def _matrix_diag_part_tbe():
"""MatrixDiagPartD TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixSetDiagD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
matrix_diag_d_op_info = TBERegOp("MatrixSetDiag") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("matrix_diag_d.so") \
.compute_cost(10) \
.kernel_name("matrix_diag_d") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "diagonal", False, "required", "all") \
.input(2, "assist", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(matrix_diag_d_op_info)
def _matrix_set_diag_tbe():
"""MatrixSetDiagD TBE register"""
return
...@@ -367,3 +367,144 @@ class LinSpace(PrimitiveWithInfer): ...@@ -367,3 +367,144 @@ class LinSpace(PrimitiveWithInfer):
args = {"assist": assist, "start": start, "stop": stop} args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensor_type_same(args, (mstype.float32,), self.name) validator.check_tensor_type_same(args, (mstype.float32,), self.name)
return assist return assist
class MatrixDiag(PrimitiveWithInfer):
"""
Returns a batched diagonal tensor with a given batched diagonal values.
Inputs:
- **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be of the following data types:
float32, float16, int32, int8, uint8.
- **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must greater than or equal to 2 and
it's last dimension must equal to the second to last dimension.
Outputs:
Tensor, has the same type and shape as input `assist`.
Examples:
>>> x = Tensor(np.array([1, -1]), mstype.float32)
>>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
>>> matrix_diag = P.MatrixDiag()
>>> result = matrix_diag(x, assist)
[[[-12. 11.]
[-10. 9.]]
[[ -8. 7.]
[ -6. 5.]]
[[ -4. 3.]
[ -2. 1.]]]
"""
@prim_attr_register
def __init__(self):
"""init MatrixDiag"""
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
validator.check('rank of x', len(x_shape)+1,
'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
assist_shape[-1], Rel.EQ, self.name)
r_end_dim = -len(x_shape)
r_idx = -1
while r_idx >= r_end_dim:
if x_shape[r_idx] != 1:
validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
r_idx = r_idx - 1
return assist_shape
class MatrixDiagPart(PrimitiveWithInfer):
r"""
Returns the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be of the following data types:
float32, float16, int32, int8, uint8.
- **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
Outputs:
Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])].
Examples:
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
>>> matrix_diag_part = P.MatrixDiagPart()
>>> result = matrix_diag_part(x, assist)
[[12., -9.], [8., -5.], [4., -1.]]
"""
@prim_attr_register
def __init__(self):
"""init MatrixDiagPart"""
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if assist_shape[-2] < assist_shape[-1]:
out_shape = assist_shape[:-1]
else:
out_shape = assist_shape[:-2] + assist_shape[-1:]
return out_shape
class MatrixSetDiag(PrimitiveWithInfer):
r"""
Modify the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be of the following data types:
float32, float16, int32, int8, uint8.
- **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
- **diagonal** (Tensor) - The diagonal values.
Outputs:
Tensor, data type same as input `x`. The shape same as `x`.
Examples:
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
>>> matrix_set_diag = P.MatrixSetDiag()
>>> result = matrix_set_diag(x, diagonal)
[[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
"""
@prim_attr_register
def __init__(self):
"""init MatrixSetDiag"""
def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, diagonal_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if x_shape[-2] < x_shape[-1]:
validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape",
diagonal_shape, Rel.EQ, self.name)
else:
validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:],
"diagonal shape", diagonal_shape, Rel.EQ, self.name)
return assist_shape
...@@ -370,6 +370,7 @@ def test_conv2d_same_primitive(): ...@@ -370,6 +370,7 @@ def test_conv2d_same_primitive():
super(Conv2DSameNet, self).__init__() super(Conv2DSameNet, self).__init__()
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
def construct(self, x, y): def construct(self, x, y):
r1 = self.conv1(x) r1 = self.conv1(x)
r2 = self.conv2(y) r2 = self.conv2(y)
...@@ -576,6 +577,22 @@ test_cases = [ ...@@ -576,6 +577,22 @@ test_cases = [
Tensor(np.ones([1, 3, 4, 4], np.float32)), Tensor(np.ones([1, 3, 4, 4], np.float32)),
Tensor(np.ones(3, np.float32))], Tensor(np.ones(3, np.float32))],
}), }),
('MatrixDiag', {
'block': nn.MatrixDiag(),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
'skip': ['backward']
}),
('MatrixDiagPart', {
'block': nn.MatrixDiagPart(),
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))],
'skip': ['backward']
}),
('MatrixSetDiag', {
'block': nn.MatrixSetDiag(),
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)),
Tensor(np.array([1, 2]).astype(np.float32))],
'skip': ['backward']
}),
] ]
test_cases_for_verify_exception = [ test_cases_for_verify_exception = [
......
...@@ -1612,6 +1612,25 @@ test_case_array_ops = [ ...@@ -1612,6 +1612,25 @@ test_case_array_ops = [
Tensor(5, mstype.int32)], Tensor(5, mstype.int32)],
'skip': ['backward'], 'skip': ['backward'],
}), }),
('MatrixDiag', {
'block': inner.MatrixDiag(),
'desc_inputs': [Tensor(np.array([1, -1]), mstype.float32),
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('MatrixDiagPart', {
'block': inner.MatrixDiagPart(),
'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32),
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('MatrixSetDiag', {
'block': inner.MatrixSetDiag(),
'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32),
Tensor(np.arange(6).reshape(3, 2), mstype.float32),
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
] ]
test_case_other_ops = [ test_case_other_ops = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册