From a45dfca142b20a86e8bebba78203e2ee79442815 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 26 May 2020 15:02:47 +0800 Subject: [PATCH] support BatchToSpaceND and SpaceToBatchND --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 2 + mindspore/ops/_grad/grad_array_ops.py | 20 +++ mindspore/ops/_op_impl/tbe/__init__.py | 2 + .../ops/_op_impl/tbe/batch_to_space_nd.py | 38 +++++ .../ops/_op_impl/tbe/space_to_batch_nd.py | 38 +++++ mindspore/ops/operations/__init__.py | 5 +- mindspore/ops/operations/array_ops.py | 160 ++++++++++++++++++ tests/ut/python/ops/test_array_ops.py | 27 +++ 8 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/tbe/batch_to_space_nd.py create mode 100644 mindspore/ops/_op_impl/tbe/space_to_batch_nd.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 64edeefcb..34ee3753f 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -82,6 +82,8 @@ static std::map tbe_func_adapter_map = { {"argmax", "arg_max_d"}, {"space_to_batch", "space_to_batch_d"}, {"batch_to_space", "batch_to_space_d"}, + {"space_to_batch_nd", "space_to_batch_nd_d"}, + {"batch_to_space_nd", "batch_to_space_nd_d"}, {"resize_bilinear", "resize_bilinear_v2_d"}, {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, {"adam", "apply_adam"}, diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index aacb94ac6..e7f181ad7 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -536,3 +536,23 @@ def get_bprop_batch_to_space(self): dx = batch_to_space_grad(dout) return (dx,) return bprop + + +@bprop_getters.register(P.SpaceToBatchND) +def get_bprop_space_to_batch_nd(self): + """Generate bprop for SpaceToBatchND""" + space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings) + def bprop(x, out, dout): + dx = space_to_batch_nd_grad(dout) + return (dx,) + return bprop + + +@bprop_getters.register(P.BatchToSpaceND) +def get_bprop_batch_to_space_nd(self): + """Generate bprop for BatchToSpaceND""" + batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops) + def bprop(x, out, dout): + dx = batch_to_space_nd_grad(dout) + return (dx,) + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 744245e2a..6051b4e1f 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -200,3 +200,5 @@ from .reduce_prod import _reduce_prod_tbe from .flatten_grad import _flatten_grad_tbe from .scatter_add import _scatter_add_tbe from .atan2 import _atan2_tbe +from .batch_to_space_nd import _batch_to_space_nd_tbe +from .space_to_batch_nd import _space_to_batch_nd_tbe diff --git a/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py b/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py new file mode 100644 index 000000000..ad5060e7c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BatchToSpaceND op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batch_to_space_nd_d.so") \ + .compute_cost(10) \ + .kernel_name("batch_to_space_nd_d") \ + .partial_flag(True) \ + .attr("block_shape", "required", "listInt", "all") \ + .attr("crops", "required", "listListInt", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(batch_to_space_nd_op_info) +def _batch_to_space_nd_tbe(): + """BatchToSpaceND TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py b/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py new file mode 100644 index 000000000..3a50b56a2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SpaceToBatchND op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("space_to_batch_nd_d.so") \ + .compute_cost(10) \ + .kernel_name("space_to_batch_nd_d") \ + .partial_flag(True) \ + .attr("block_shape", "required", "listInt", "all") \ + .attr("paddings", "required", "listListInt", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(space_to_batch_nd_op_info) +def _space_to_batch_nd_tbe(): + """SpaceToBatchND TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 5af72eb03..5deb3623e 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -29,7 +29,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, - UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) + UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, + SpaceToBatchND, BatchToSpaceND) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice) @@ -260,6 +261,8 @@ __all__ = [ "Atan2", "ApplyRMSProp", "ApplyCenteredRMSProp", + "SpaceToBatchND", + "BatchToSpaceND", "SquareSumAll" ] diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 280e24f7c..1b33533dc 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2491,3 +2491,163 @@ class BatchToSpace(PrimitiveWithInfer): f'block_size_prod {block_size_prod}') out_shape[0] = out_shape[0] // block_size_prod return out_shape + + +class SpaceToBatchND(PrimitiveWithInfer): + r""" + Divide spatial dimensions into blocks and combine the block size with the original batch. + + This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W + dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the + product of the original batch and the product of block_shape. Prior to division into blocks, the spatial dimensions + of the input are zero padded according to paddings if necessary. + + Args: + block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1. + The length of block_shape is M correspoding to the number of spatial dimensions. + paddings (list): The padding value for H and W dimension, containing M sub list, each containing 2 int value. + All values must be >= 0. paddings[i] specifies the paddings for spatial dimension i, which corresponds to + input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible + by block_shape[i]. + + Inputs: + - **input_x** (Tensor) - The input tensor. + + Outputs: + Tensor, the output tensor with the same type as input. Assume input shape is :math:`(n, c, h, w)` with + :math:`block\_shape` and :math:`padddings`. The output tensor shape will be :math:`(n', c', h', w')`, where + + :math:`n' = n*(block\_shape[0]*block\_shape[1])` + + :math:`c' = c` + + :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]` + + :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]` + + Examples: + >>> block_shape = [2, 2] + >>> paddings = [[0, 0], [0, 0]] + >>> space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings) + >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32) + >>> space_to_batch_nd(input_x) + [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]] + + """ + + @prim_attr_register + def __init__(self, block_shape, paddings): + """Init SpaceToBatchND""" + validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name) + validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name) + block_rank = len(block_shape) + + for elem in block_shape: + validator.check('block_shape element', elem, '', 1, Rel.GE, self.name) + self.block_shape = block_shape + + validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name) + for elem in itertools.chain(*paddings): + validator.check_integer('paddings element', elem, 0, Rel.GE, self.name) + validator.check_value_type('paddings element', elem, [int], self.name) + self.paddings = paddings + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + return x_dtype + + def infer_shape(self, x_shape): + x_rank = len(x_shape) + out_shape = copy.deepcopy(x_shape) + + block_shape_prod = 1 + for i in range(x_rank - 2): + padded = out_shape[i + 2] + self.paddings[i][0] + \ + self.paddings[i][1] + if padded % self.block_shape[i] != 0: + raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' + f'block_shape[{i}] {self.block_shape[i]}') + out_shape[i + 2] = padded // self.block_shape[i] + block_shape_prod = block_shape_prod * self.block_shape[i] + out_shape[0] *= block_shape_prod + return out_shape + + +class BatchToSpaceND(PrimitiveWithInfer): + r""" + Divide batch dimension with blocks and interleaves these blocks back into spatial dimensions. + + This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension + is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W + dimension and block_shape with given amount to crop from dimension, respectively. + + Args: + block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1. + The length of block_shape is M correspoding to the number of spatial dimensions. + crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value. + All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to + input dimension i+2. It is required that input_shape[i+2]*block_size[i] >= crops[i][0]+crops[i][1]. + + Inputs: + - **input_x** (Tensor) - The input tensor. + + Outputs: + Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape + and crops. The output shape will be (n', c', h', w'), where + + :math:`n' = n//(block\_shape[0]*block\_shape[1])` + + :math:`c' = c` + + :math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]` + + :math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]` + + Examples: + >>> block_shape = [2, 2] + >>> crops = [[0, 0], [0, 0]] + >>> batch_to_space_nd = P.BatchToSpaceND(block_shape, crops) + >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32) + >>> output = batch_to_space_nd(input_x) + [[[[1., 2.], [3., 4.]]]] + + """ + + @prim_attr_register + def __init__(self, block_shape, crops): + """Init BatchToSpaceND""" + validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name) + validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name) + block_rank = len(block_shape) + + for elem in block_shape: + validator.check('block_shape element', elem, '', 1, Rel.GE, self.name) + self.block_shape = block_shape + + validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name) + for elem in itertools.chain(*crops): + validator.check_integer('crops element', elem, 0, Rel.GE, self.name) + validator.check_value_type('crops element', elem, [int], self.name) + self.crops = crops + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + return x_dtype + + def infer_shape(self, x_shape): + x_rank = len(x_shape) + out_shape = copy.deepcopy(x_shape) + + block_shape_prod = 1 + for i in range(x_rank - 2): + block_shape_prod = block_shape_prod * self.block_shape[i] + x_block_prod = out_shape[i + 2] * self.block_shape[i] + crops_sum = self.crops[i][0] + self.crops[i][1] + validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) + out_shape[i + 2] = x_block_prod - crops_sum + + if out_shape[0] % block_shape_prod != 0: + raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' + f'block_shape_prod {block_shape_prod}') + out_shape[0] = out_shape[0] // block_shape_prod + return out_shape diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index 3ade4b983..5e5fa7deb 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -264,6 +264,27 @@ class DepthToSpaceNet(Cell): return self.depth_to_space(x) +class BatchToSpaceNDNet(Cell): + def __init__(self): + super(BatchToSpaceNDNet, self).__init__() + block_shape = [2, 2] + crops = [[0, 0], [0, 0]] + self.batch_to_space_nd = P.BatchToSpaceND(block_shape, crops) + + def construct(self, x): + return self.batch_to_space_nd(x) + + +class SpaceToBatchNDNet(Cell): + def __init__(self): + super(SpaceToBatchNDNet, self).__init__() + block_shape = [2, 2] + paddings = [[0, 0], [0, 0]] + self.space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings) + + def construct(self, x): + return self.space_to_batch_nd(x) + test_case_array_ops = [ ('CustNet1', { 'block': CustNet1(), @@ -298,6 +319,12 @@ test_case_array_ops = [ ('DepthToSpaceNet', { 'block': DepthToSpaceNet(), 'desc_inputs': [Tensor(np.random.rand(1,12,1,1).astype(np.float16))]}), + ('SpaceToBatchNDNet', { + 'block': SpaceToBatchNDNet(), + 'desc_inputs': [Tensor(np.random.rand(1,1,2,2).astype(np.float16))]}), + ('BatchToSpaceNDNet', { + 'block': BatchToSpaceNDNet(), + 'desc_inputs': [Tensor(np.random.rand(4,1,1,1).astype(np.float16))]}), ] test_case_lists = [test_case_array_ops] -- GitLab