提交 fe8b59f2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1111 support vm for pack and unpack

Merge pull request !1111 from jiangjinsheng/vm_pack
...@@ -182,3 +182,5 @@ from .sgd import sgd_op_info ...@@ -182,3 +182,5 @@ from .sgd import sgd_op_info
from .lars_update import lars_update_op_info from .lars_update import lars_update_op_info
from .bn_training_update_v2 import _bn_training_update_v2_tbe from .bn_training_update_v2 import _bn_training_update_v2_tbe
from .square_sum_all import square_sum_all_op_info from .square_sum_all import square_sum_all_op_info
from .pack import _pack_tbe
from .unpack import _unpack_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pack op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
pack_op_info = TBERegOp("Pack") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("pack.so") \
.compute_cost(10) \
.kernel_name("pack") \
.partial_flag(True) \
.attr("axis", "optional", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_NDHWC, DataType.I8_NDHWC) \
.dtype_format(DataType.I16_NDHWC, DataType.I16_NDHWC) \
.dtype_format(DataType.I32_NDHWC, DataType.I32_NDHWC) \
.dtype_format(DataType.I64_NDHWC, DataType.I64_NDHWC) \
.dtype_format(DataType.U8_NDHWC, DataType.U8_NDHWC) \
.dtype_format(DataType.U16_NDHWC, DataType.U16_NDHWC) \
.dtype_format(DataType.U32_NDHWC, DataType.U32_NDHWC) \
.dtype_format(DataType.U64_NDHWC, DataType.U64_NDHWC) \
.dtype_format(DataType.F16_NDHWC, DataType.F16_NDHWC) \
.dtype_format(DataType.F32_NDHWC, DataType.F32_NDHWC) \
.dtype_format(DataType.BOOL_NDHWC, DataType.BOOL_NDHWC) \
.get_op_info()
@op_info_register(pack_op_info)
def _pack_tbe():
"""Pack TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unpack op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
unpack_op_info = TBERegOp("Unpack") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("unpack.so") \
.compute_cost(10) \
.kernel_name("unpack") \
.partial_flag(True) \
.attr("num", "optional", "int", "all") \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
.dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
.dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(unpack_op_info)
def _unpack_tbe():
"""Unpack TBE register"""
return
...@@ -499,6 +499,7 @@ class DataType: ...@@ -499,6 +499,7 @@ class DataType:
BOOL_NCHW = ("bool", "NCHW") BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC") BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN") BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
I8_None = ("int8", "") I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat") I8_Default = ("int8", "DefaultFormat")
...@@ -509,6 +510,7 @@ class DataType: ...@@ -509,6 +510,7 @@ class DataType:
I8_NCHW = ("int8", "NCHW") I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC") I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN") I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
U8_None = ("uint8", "") U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat") U8_Default = ("uint8", "DefaultFormat")
...@@ -519,6 +521,7 @@ class DataType: ...@@ -519,6 +521,7 @@ class DataType:
U8_NCHW = ("uint8", "NCHW") U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC") U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN") U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
I16_None = ("int16", "") I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat") I16_Default = ("int16", "DefaultFormat")
...@@ -529,6 +532,7 @@ class DataType: ...@@ -529,6 +532,7 @@ class DataType:
I16_NCHW = ("int16", "NCHW") I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC") I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN") I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
U16_None = ("uint16", "") U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat") U16_Default = ("uint16", "DefaultFormat")
...@@ -539,6 +543,7 @@ class DataType: ...@@ -539,6 +543,7 @@ class DataType:
U16_NCHW = ("uint16", "NCHW") U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC") U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN") U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
I32_None = ("int32", "") I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat") I32_Default = ("int32", "DefaultFormat")
...@@ -549,6 +554,7 @@ class DataType: ...@@ -549,6 +554,7 @@ class DataType:
I32_NCHW = ("int32", "NCHW") I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC") I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN") I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
U32_None = ("uint32", "") U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat") U32_Default = ("uint32", "DefaultFormat")
...@@ -559,6 +565,7 @@ class DataType: ...@@ -559,6 +565,7 @@ class DataType:
U32_NCHW = ("uint32", "NCHW") U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC") U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN") U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
I64_None = ("int64", "") I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat") I64_Default = ("int64", "DefaultFormat")
...@@ -569,6 +576,7 @@ class DataType: ...@@ -569,6 +576,7 @@ class DataType:
I64_NCHW = ("int64", "NCHW") I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC") I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN") I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
U64_None = ("uint64", "") U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat") U64_Default = ("uint64", "DefaultFormat")
...@@ -579,6 +587,7 @@ class DataType: ...@@ -579,6 +587,7 @@ class DataType:
U64_NCHW = ("uint64", "NCHW") U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC") U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN") U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
F16_None = ("float16", "") F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat") F16_Default = ("float16", "DefaultFormat")
...@@ -589,6 +598,7 @@ class DataType: ...@@ -589,6 +598,7 @@ class DataType:
F16_NCHW = ("float16", "NCHW") F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC") F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN") F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F32_None = ("float32", "") F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat") F32_Default = ("float32", "DefaultFormat")
...@@ -599,6 +609,7 @@ class DataType: ...@@ -599,6 +609,7 @@ class DataType:
F32_NCHW = ("float32", "NCHW") F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC") F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN") F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F64_None = ("float64", "") F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat") F64_Default = ("float64", "DefaultFormat")
...@@ -609,3 +620,4 @@ class DataType: ...@@ -609,3 +620,4 @@ class DataType:
F64_NCHW = ("float64", "NCHW") F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC") F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN") F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
...@@ -227,6 +227,23 @@ class SpaceToBatchNet(Cell): ...@@ -227,6 +227,23 @@ class SpaceToBatchNet(Cell):
return self.space_to_batch(x) return self.space_to_batch(x)
class PackNet(Cell):
def __init__(self):
super(PackNet, self).__init__()
self.pack = P.Pack()
def construct(self, x):
return self.pack((x, x))
class UnpackNet(Cell):
def __init__(self):
super(UnpackNet, self).__init__()
self.unpack = P.Unpack()
def construct(self, x):
return self.unpack(x)
test_case_array_ops = [ test_case_array_ops = [
('CustNet1', { ('CustNet1', {
'block': CustNet1(), 'block': CustNet1(),
...@@ -249,6 +266,12 @@ test_case_array_ops = [ ...@@ -249,6 +266,12 @@ test_case_array_ops = [
('SpaceToBatchNet', { ('SpaceToBatchNet', {
'block': SpaceToBatchNet(), 'block': SpaceToBatchNet(),
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}), 'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
('PackNet', {
'block': PackNet(),
'desc_inputs': [Tensor(np.array([[[1, 2], [3, 4]]]).astype(np.float16))]}),
('UnpackNet', {
'block': UnpackNet(),
'desc_inputs': [Tensor(np.array([[1, 2], [3, 4]]).astype(np.float16))]}),
] ]
test_case_lists = [test_case_array_ops] test_case_lists = [test_case_array_ops]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册