From 62807da0c4fa6caa0c941524483959c0263f51c5 Mon Sep 17 00:00:00 2001
From: buxue
Date: Wed, 25 Mar 2020 18:38:10 +0800
Subject: [PATCH] Develop operator Unfold on the GE backend, docking with
 TBE's ExtractImagePatches operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspore/ccsrc/kernel/tbe/tbe_adapter.h |   6 --
 mindspore/ccsrc/transform/convert.cc     |   2 +
 mindspore/ccsrc/transform/op_adapter.h   |  26 +++---
 mindspore/ccsrc/transform/op_declare.cc  |  14 ++--
 mindspore/ccsrc/transform/op_declare.h   |   2 +
 mindspore/ccsrc/utils/utils.h            |   1 +
 mindspore/nn/layer/__init__.py           |   4 +-
 mindspore/nn/layer/basic.py              |  48 +++++++++++
 mindspore/ops/_grad/grad_nn_ops.py       |  57 ++++++++++++-
 mindspore/ops/operations/__init__.py     |   3 +-
 mindspore/ops/operations/math_ops.py     |   2 +-
 mindspore/ops/operations/nn_ops.py       | 100 ++++++++++++++++++++++-
 tests/ut/python/ops/test_math_ops.py     |  32 ++++++--
 tests/ut/python/ops/test_nn_ops.py       |  55 +++++++++++++
 14 files changed, 310 insertions(+), 42 deletions(-)

diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h b/mindspore/ccsrc/kernel/tbe/tbe_adapter.h
index 3997318c8..27f6d315f 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.h
@@ -45,12 +45,6 @@ class TbeAdapter {
                             std::vector *input_list, kCreaterType creater_type);

  private:
-  static void MaxPoolWithArgmaxAttrJsonPass(const AnfNodePtr &anf_node,
-                                            const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
-                                            nlohmann::json *attrs_json);
-  static void MaxPoolGradWithArgmaxAttrJsonPass(const AnfNodePtr &anf_node,
-                                                const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
-                                                nlohmann::json *attrs_json);
   static void Conv2DAttrJsonPass(const AnfNodePtr &anf_node,
                                  const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
                                  nlohmann::json *attrs_json);
   static void Conv2DBackpropFilterAttrJsonPass(const AnfNodePtr &anf_node,
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 251946f6f..2daa86b96 100755
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -96,6 +96,7 @@ const char kNameConfusionMatrix[] = "ConfusionMatrix";
 const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
 const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
 const char kNameApplyAdam[] = "Adam";
+const char kNameExtractImagePatches[] = "ExtractImagePatches";
 const char kNameReLU6[] = "ReLU6";
 const char kNameReLU6Grad[] = "ReLU6Grad";
 const char kNameElu[] = "Elu";
@@ -214,6 +215,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
   {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
   {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
   {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
+  {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
   {prim::kPrimAssign->name(), ADPT_DESC(Assign)},
   {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
   {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h
index 7f20a8803..421e4c456 100644
--- a/mindspore/ccsrc/transform/op_adapter.h
+++ b/mindspore/ccsrc/transform/op_adapter.h
@@ -322,18 +322,12 @@ class OpAdapter : public BaseOpAdapter {
   Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) {
     MS_EXCEPTION_IF_NULL(type);
-    TypeId me_type = type->type_id();
-    if (kObjectTypeTensorType == me_type) {
-      me_type = dyn_cast<TensorType>(type)->element()->type_id();
-    }
-
-    std::vector<int> shape;
-    auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
-    if (nullptr != normal_shape_ptr) {
-      shape = normal_shape_ptr->shape();
+    std::string format = "NCHW";
+    if (op->GetOpType() == kExtractImagePatchesOpName) {
+      format = "NHWC";
     }
-    auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
+    auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
     if (desc == nullptr) {
       MS_LOG(ERROR) << "Update output descriptor failed!";
       return FAILED;
@@ -410,14 +404,15 @@ class OpAdapter : public BaseOpAdapter {
       MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
       return FAILED;
     }
+    std::string format = "NCHW";
+    if (op->GetOpType() == kTopKOpName) {
+      format = "NHWC";
+    }
     for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
       auto tuple_type = dyn_cast<Tuple>(type);
       MS_EXCEPTION_IF_NULL(tuple_type);
       TypePtr type_elem = tuple_type->elements()[i];
-      std::string format = "NCHW";
-      if (op->GetOpType() == kTopKOpName) {
-        format = "NHWC";
-      }
+
       auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(tuple_shp->shape()[i]), type_elem, format);
       if (desc == nullptr) {
         MS_LOG(ERROR) << "Create output descriptor failed!";
@@ -476,6 +471,9 @@ class OpAdapter : public BaseOpAdapter {
       if (desc == nullptr) {
         continue;
       }
+      if (op->GetOpType() == kExtractImagePatchesOpName) {
+        desc->SetFormat(ge::Format::FORMAT_NHWC);
+      }
       it->second.update_input_desc(op, *desc);
     }
   }
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index f821c71d8..420edc685 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -751,16 +751,20 @@ ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyT
 OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}};

 // MaxPoolGradWithArgmax
-INPUT_MAP(MaxPoolGradWithArgmax) = {
-  {1, INPUT_DESC(x)},
-  {2, INPUT_DESC(grad)},
-  {3, INPUT_DESC(argmax)},
-};
+INPUT_MAP(MaxPoolGradWithArgmax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)}};
 ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
                                    {"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
                                    {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
 OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}};

+// ExtractImagePatches
+INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(images)}};
+ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
+                                 {"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
+                                 {"rates", ATTR_DESC(rates, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
+                                 {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
+OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
+
 // Conv2D
 INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
 ATTR_MAP(Conv2D) = {
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index 8f6dda943..8b32e16b3 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -95,6 +95,8 @@ DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
 DECLARE_OP_ADAPTER(Conv2D)
 DECLARE_OP_USE_ENUM(Conv2D)
 DECLARE_OP_USE_OUTPUT(Conv2D)
+DECLARE_OP_ADAPTER(ExtractImagePatches)
+DECLARE_OP_USE_OUTPUT(ExtractImagePatches)
 DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
 DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
 DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 08a98a312..44e7b4d4c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -49,6 +49,7 @@ constexpr auto kBroadcastOpName = "Broadcast";
 constexpr auto kReduceScatterOpName = "ReduceScatter";
 constexpr auto kMemCpyAsyncOpName = "memcpy_async";
 constexpr auto kTopKOpName = "TopK";
+constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
 constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
 constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
 constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad";
diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 9c2c30c91..3d729edcd 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
 from .container import SequentialCell, CellList
 from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
-from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad
+from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
 from .embedding import Embedding
 from .pooling import AvgPool2d, MaxPool2d
 from .image import ImageGradients, SSIM
@@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'LSTM', 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
            'Embedding',
-           'AvgPool2d', 'MaxPool2d', 'Pad',
+           'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
            'ImageGradients', 'SSIM', ]
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 64c4cfd93..5ac52acac 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -439,3 +439,51 @@ class Pad(Cell):
         else:
             x = self.pad(x, self.paddings)
         return x
+
+
+class Unfold(Cell):
+    """
+    Extracts patches from images.
+    The input tensor must be a 4-D tensor whose data format is NCHW.
+
+    Args:
+        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, a tuple or list of ints
+            in the format [1, ksize_row, ksize_col, 1].
+        strides (Union[tuple[int], list[int]]): The distance between the centers of two consecutive
+            patches, a tuple or list of ints in the format [1, stride_row, stride_col, 1].
+        rates (Union[tuple[int], list[int]]): The gap between consecutive pixel positions within each
+            extracted patch, a tuple or list of ints in the format [1, rate_row, rate_col, 1].
+        padding (str): The padding algorithm, "same" or "valid" (case insensitive). Default: "valid".
+
+            - same: The patch may extend beyond the original image, and the part outside is filled with 0.
+
+            - valid: Every patch must be fully contained in the original image.
+
+    Inputs:
+        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
+          whose data type is int8, float16 or float32.
+
+    Outputs:
+        Tensor, a 4-D tensor with the same data type as `input_x`, whose shape is
+        [out_batch, out_depth, out_row, out_col], where out_batch equals in_batch.
+
+    Examples:
+        >>> net = Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
+        >>> image = Tensor(np.ones([1, 1, 3, 3]), dtype=mstype.float16)
+        >>> net(image)
+        Tensor ([[[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]],
+                shape=(1, 4, 2, 2), dtype=mstype.float16)
+    """
+    def __init__(self, ksizes, strides, rates, padding="valid"):
+        super(Unfold, self).__init__()
+        self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding)
+        self.transpose = P.Transpose()
+        self.format_NHWC = (0, 2, 3, 1)
+        self.format_NCHW = (0, 3, 1, 2)
+
+    def construct(self, input_x):
+        x_transpose = self.transpose(input_x, self.format_NHWC)
+        ret = self.extract_image_patches(x_transpose)
+        ret_transpose = self.transpose(ret, self.format_NCHW)
+        return ret_transpose
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 149dd6cae..ae730d78a 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -14,7 +14,7 @@
 # ============================================================================

 """Define the grad rules of neural network related operations."""
-
+from mindspore.common import dtype as mstype
 from .. import functional as F
 from .. import operations as P
 from ..operations import _grad_ops as G
@@ -52,6 +52,61 @@ def get_bprop_conv2d(self):
     return bprop


+@bprop_getters.register(P.ExtractImagePatches)
+def get_bprop_extract_image_patches(self):
+    """Grad definition for `ExtractImagePatches` operation."""
+    get_shape = P.Shape()
+    reshape = P.Reshape()
+    extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes,
+                                                  strides=self.strides,
+                                                  rates=self.rates,
+                                                  padding=self.padding)
+    concat = P.Concat(axis=-1)
+    expand_dims = P.ExpandDims()
+    scatter_nd = P.ScatterNd()
+    dtype = P.DType()
+    fill = P.Fill()
+    slice_op = P.Slice()
+    transpose = P.Transpose()
+    matmul = P.MatMul()
+    cast = P.Cast()
+    _, ksizes_row, ksizes_col, _ = self.ksizes
+
+    def bprop(x, out, dout):
+        x_shape = get_shape(x)
+        x_batch, x_row, x_col, x_depth = x_shape
+        x_indices_num = x_row * x_col + 1
+        x_idx = F.tuple_to_array(range(1, x_indices_num))
+        x_idx = reshape(x_idx, (1, x_row, x_col, 1))
+        x_idx = cast(x_idx, mstype.float16)
+        x_idx_patch = extract_image_patches(x_idx)
+        x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2))
+        x_idx_patch = cast(x_idx_patch, mstype.int32)
+
+        out_shape = get_shape(out)
+        _, out_row, out_col, _ = out_shape
+        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
+        out_idx = F.tuple_to_array(range(out_indices_num))
+        out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col))
+
+        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
+        idx_tensor = reshape(idx_tensor, (-1, 2))
+        sp_shape = (x_indices_num, out_indices_num)
+        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
+        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
+
+        grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
+        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
+        grad = reshape(grad, (-1, x_batch * x_depth))
+
+        jac = matmul(sp_tensor, grad)
+        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
+        dx = transpose(dx, (2, 0, 1, 3))
+
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.DepthwiseConv2dNative)
 def get_bprop_depthwise_conv2d_native(self):
     """Grad definition for `DepthwiseConv2dNative` operation."""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 40cbfc338..492ebae44 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, LogSoftmax,
-                     MaxPool,
+                     MaxPool, ExtractImagePatches,
                      AvgPool, Conv2DBackpropInput,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid,
@@ -89,6 +89,7 @@ __all__ = [
     'Sqrt',
     'Square',
     'Conv2D',
+    'ExtractImagePatches',
     'Flatten',
     'MaxPoolWithArgmax',
     'FusedBatchNorm',
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index f6feb1af1..e390b6b58 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -1475,7 +1475,7 @@ class LogicalNot(PrimitiveWithInfer):
     Computes the "logical NOT" of a tensor element-wise.

     Inputs:
-        - **input_x** (Tensor) - The input tensor whose dtype is bool
+        - **input_x** (Tensor) - The input tensor whose dtype is bool.

     Outputs:
         Tensor, the shape is same as the `input_x`, and the dtype is bool.
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 377ef1941..a92a75c78 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2546,6 +2546,7 @@ class ApplyFtrl(PrimitiveWithInfer):
     Outputs:
         Tensor, representing the updated var.
     """
+
     @prim_attr_register
     def __init__(self, use_locking=False):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
@@ -2566,8 +2567,99 @@ class ApplyFtrl(PrimitiveWithInfer):
         args = {'var_type': var_type, 'accum_type': accum_type, 'linear_type': linear_type,
                 'grad_type': grad_type}
         validator.check_type_same(args, (mstype.float32, mstype.float16))
-        validator.check_typename("lr", lr_type,[mstype.float16, mstype.float32])
-        validator.check_typename("l1", l1_type,[mstype.float16, mstype.float32])
-        validator.check_typename("l2", l2_type,[mstype.float16, mstype.float32])
-        validator.check_typename("lr_power", lr_power_type,[mstype.float16, mstype.float32])
+        validator.check_typename("lr", lr_type, [mstype.float16, mstype.float32])
+        validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32])
+        validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32])
+        validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32])
         return var_type
+
+
+class ExtractImagePatches(PrimitiveWithInfer):
+    """
+    Extracts patches from images.
+    The input tensor must be a 4-D tensor whose data format is NHWC.
+
+    Args:
+        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, a tuple or list of ints
+            in the format [1, ksize_row, ksize_col, 1].
+        strides (Union[tuple[int], list[int]]): The distance between the centers of two consecutive
+            patches, a tuple or list of ints in the format [1, stride_row, stride_col, 1].
+        rates (Union[tuple[int], list[int]]): The gap between consecutive pixel positions within each
+            extracted patch, a tuple or list of ints in the format [1, rate_row, rate_col, 1].
+        padding (str): The padding algorithm, "same" or "valid" (case insensitive). Default: "valid".
+
+            - same: The patch may extend beyond the original image, and the part outside is filled with 0.
+
+            - valid: Every patch must be fully contained in the original image.
+
+    Inputs:
+        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
+          whose data type is int8, float16 or float32.
+
+    Outputs:
+        Tensor, a 4-D tensor with the same data type as `input_x`, whose shape is
+        [out_batch, out_row, out_col, out_depth], where out_batch equals in_batch.
+    """
+
+    @prim_attr_register
+    def __init__(self, ksizes, strides, rates, padding="valid"):
+        """init"""
+        validator.check_type("ksizes", ksizes, [tuple, list])
+        validator.check_type("strides", strides, [tuple, list])
+        validator.check_type("rates", rates, [tuple, list])
+        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
+        self.add_prim_attr("padding", self.padding)
+
+        if len(ksizes) != 4 or ksizes[0] != 1 or ksizes[3] != 1:
+            raise ValueError("The format of ksizes should be [1, ksize_row, ksize_col, 1], "
+                             f"but got {ksizes}.")
+        if not isinstance(ksizes[1], int) or not isinstance(ksizes[2], int) or \
+                ksizes[1] < 1 or ksizes[2] < 1:
+            raise ValueError("The ksize_row and ksize_col in ksizes should be positive integers, "
+                             f"but got ksize_row {ksizes[1]} and ksize_col {ksizes[2]}.")
+
+        if len(strides) != 4 or strides[0] != 1 or strides[3] != 1:
+            raise ValueError("The format of strides should be [1, stride_row, stride_col, 1], "
+                             f"but got {strides}.")
+        if not isinstance(strides[1], int) or not isinstance(strides[2], int) or \
+                strides[1] < 1 or strides[2] < 1:
+            raise ValueError("The stride_row and stride_col in strides should be positive integers, "
+                             f"but got stride_row {strides[1]} and stride_col {strides[2]}.")
+
+        if len(rates) != 4 or rates[0] != 1 or rates[3] != 1:
+            raise ValueError("The format of rates should be [1, rate_row, rate_col, 1], "
+                             f"but got {rates}.")
+        if not isinstance(rates[1], int) or not isinstance(rates[2], int) or \
+                rates[1] < 1 or rates[2] < 1:
+            raise ValueError("The rate_row and rate_col in rates should be positive integers, "
+                             f"but got rate_row {rates[1]} and rate_col {rates[2]}.")
+
+    def infer_shape(self, input_x):
+        if len(input_x) != 4:
+            raise ValueError("The `input_x` should be a 4-D tensor, "
+                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}.")
+        in_batch, in_row, in_col, in_depth = input_x
+        _, ksize_row, ksize_col, _ = self.ksizes
+        _, stride_row, stride_col, _ = self.strides
+        _, rate_row, rate_col, _ = self.rates
+
+        out_batch = in_batch
+        out_depth = ksize_row * ksize_col * in_depth
+
+        if self.padding == "VALID":
+            out_row = (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
+            out_col = (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
+        else:
+            out_row = (in_row - 1) // stride_row + 1
+            out_col = (in_col - 1) // stride_col + 1
+
+        out_shape = [out_batch, out_row, out_col, out_depth]
+        return out_shape
+
+    def infer_dtype(self, input_x):
+        validator.check_subclass("input_x", input_x, mstype.tensor)
+        validator.check_typename("input_x_dtype", input_x, (mstype.int8, mstype.float16, mstype.float32))
+        return input_x
diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py
index ad1642228..8b7f627e8 100755
--- a/tests/ut/python/ops/test_math_ops.py
+++ b/tests/ut/python/ops/test_math_ops.py
@@ -30,6 +30,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
     import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
 from ....mindspore_test_framework.pipeline.forward.verify_exception \
     import pipeline_for_verify_exception_for_case_by_case_config
+
+
 # pylint: disable=W0613
 # pylint: disable=W0231
 # W0613: unused-argument
@@ -106,7 +108,7 @@ def test_realdiv():
     result = div(x, y)
     x = x.asnumpy()
     y = y.asnumpy()
-    expect = x/y
+    expect = x / y
     assert np.all(result.asnumpy() == expect)


@@ -122,6 +124,7 @@ def test_eye():

 class VirtualLossGrad(PrimitiveWithInfer):
     """ VirtualLossGrad definition """
+
     @prim_attr_register
     def __init__(self):
         """init VirtualLossGrad"""
@@ -138,6 +141,7 @@ class VirtualLossGrad(PrimitiveWithInfer):

 class VirtualLoss(PrimitiveWithInfer):
     """ VirtualLoss definition """
+
     @prim_attr_register
     def __init__(self):
         """init VirtualLoss"""
@@ -151,6 +155,7 @@ class VirtualLoss(PrimitiveWithInfer):
         def bprop(x, out, dout):
             dx = loss_grad(x, out, dout)
             return (dx,)
+
         return bprop

     def infer_shape(self, x_shape):
@@ -162,6 +167,7 @@ class VirtualLoss(PrimitiveWithInfer):

 class NetWithLoss(nn.Cell):
     """ NetWithLoss definition """
+
     def __init__(self, network):
         super(NetWithLoss, self).__init__()
         self.loss = VirtualLoss()
@@ -174,6 +180,7 @@ class NetWithLoss(nn.Cell):

 class GradWrap(nn.Cell):
     """ GradWrap definition """
+
     def __init__(self, network):
         super(GradWrap, self).__init__()
         self.network = network
@@ -184,6 +191,7 @@ class GradWrap(nn.Cell):

 class MatMulNet(nn.Cell):
     """ MatMulNet definition """
+
     def __init__(self):
         super(MatMulNet, self).__init__()
         self.matmul = P.MatMul()
@@ -195,6 +203,7 @@ class MatMulNet(nn.Cell):

 class NetWithLossSub(nn.Cell):
     """ NetWithLossSub definition """
+
     def __init__(self, network):
         super(NetWithLossSub, self).__init__()
         self.loss = VirtualLoss()
@@ -207,6 +216,7 @@ class NetWithLossSub(nn.Cell):

 class GradWrapSub(nn.Cell):
     """ GradWrapSub definition """
+
     def __init__(self, network):
         super(GradWrapSub, self).__init__()
         self.network = network
@@ -217,6 +227,7 @@ class GradWrapSub(nn.Cell):

 class SubNet(nn.Cell):
     """ SubNet definition """
+
     def __init__(self):
         super(SubNet, self).__init__()
         self.sub = P.Sub()
@@ -227,6 +238,7 @@ class SubNet(nn.Cell):

 class NpuFloatNet(nn.Cell):
     """ NpuFloat definition """
+
     def __init__(self):
         super(NpuFloatNet, self).__init__()
         self.mul = P.Mul()
@@ -258,6 +270,7 @@ class NpuFloatNet(nn.Cell):

 class DiagNet(nn.Cell):
     """ DiagNet definition """
+
     def __init__(self):
         super(DiagNet, self).__init__()
         self.fill = P.Fill()
@@ -269,6 +282,7 @@ class DiagNet(nn.Cell):

 class NetWithLossCumSum(nn.Cell):
     """ NetWithLossCumSum definition """
+
     def __init__(self, network):
         super(NetWithLossCumSum, self).__init__()
         self.loss = VirtualLoss()
@@ -281,6 +295,7 @@ class NetWithLossCumSum(nn.Cell):

 class GradWrapCumSum(nn.Cell):
     """ GradWrap definition """
+
     def __init__(self, network):
         super(GradWrapCumSum, self).__init__()
         self.network = network
@@ -291,6 +306,7 @@ class GradWrapCumSum(nn.Cell):

 class NetCumSum(nn.Cell):
     """ NetCumSum definition """
+
     def __init__(self):
         super(NetCumSum, self).__init__()
         self.cumsum = P.CumSum()
@@ -321,8 +337,8 @@ test_case_math_ops = [
         'skip': ['backward']}),
     ('CumSumGrad', {
         'block': GradWrapCumSum(NetWithLossCumSum(NetCumSum())),
-        'desc_inputs': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))],
-        'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))],
+        'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
+        'desc_bprop': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
         'skip': ['backward']}),
     ('Diag', {
         'block': DiagNet(),
@@ -351,7 +367,6 @@ test_case_math_ops = [
         'skip': ['backward']}),
 ]

-
 test_case_lists = [test_case_math_ops]
 test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
 # use -k to select certain testcast
@@ -360,6 +375,7 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)

 import mindspore.context as context

+
 @non_graph_engine
 @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
 def test_exec():
@@ -369,16 +385,16 @@ def test_exec():

 raise_set = [
     ('StridedSlice_1_Error', {
-        'block': (lambda x : P.StridedSlice(begin_mask="1"), {'exception': ValueError}),
+        'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': ValueError}),
         'desc_inputs': [0]}),
     ('StridedSlice_2_Error', {
-        'block': (lambda x : P.StridedSlice(end_mask="1"), {'exception': ValueError}),
+        'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': ValueError}),
         'desc_inputs': [0]}),
     ('StridedSlice_3_Error', {
-        'block': (lambda x : P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}),
+        'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}),
         'desc_inputs': [0]}),
     ('StridedSlice_4_Error', {
-        'block': (lambda x : P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}),
+        'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}),
         'desc_inputs': [0]}),
 ]
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index cadac6dfb..736489350 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -382,6 +382,46 @@ def test_max_pool_with_arg_max():
     print(ret)


+class GradWrapUnfold(nn.Cell):
+    """ GradWrapUnfold definition """
+
+    def __init__(self, network):
+        super(GradWrapUnfold, self).__init__()
+        self.network = network
+        self.sens = Tensor(np.ones([1, 4, 2, 2], np.float32))
+
+    def construct(self, x):
+        return C.grad_all_with_sens(self.network)(x, self.sens)
+
+
+class UnfoldNetValid(nn.Cell):
+    """ UnfoldNetValid definition """
+
+    def __init__(self):
+        super(UnfoldNetValid, self).__init__()
+        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
+                                strides=[1, 1, 1, 1],
+                                rates=[1, 1, 1, 1],
+                                padding='VALID')
+
+    def construct(self, x):
+        return self.unfold(x)
+
+
+class UnfoldNetSame(nn.Cell):
+    """ UnfoldNetSame definition """
+
+    def __init__(self):
+        super(UnfoldNetSame, self).__init__()
+        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
+                                strides=[1, 1, 1, 1],
+                                rates=[1, 1, 1, 1],
+                                padding='SAME')
+
+    def construct(self, x):
+        return self.unfold(x)
+
+
 test_cases = [
     ('SoftMaxGrad', {
         'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
@@ -440,6 +480,21 @@ test_cases = [
         'block': ComparisonNet(),
         'desc_inputs': [Tensor(np.ones([6, 9, 10], np.int32)), Tensor(np.ones([6, 9, 10], np.int32))],
     }),
+    ('UnfoldValid', {
+        'block': UnfoldNetValid(),
+        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
+        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
+        'skip': ['backward']}),
+    ('UnfoldSame', {
+        'block': UnfoldNetSame(),
+        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
+        'desc_bprop': [Tensor(np.ones([1, 4, 3, 3], np.float32))],
+        'skip': ['backward']}),
+    ('UnfoldGrad', {
+        'block': GradWrapUnfold(UnfoldNetValid()),
+        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
+        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
+        'skip': ['backward']}),
 ]

 test_cases_for_verify_exception = [
-- 
GitLab
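
Reviewer appendix (not part of the patch): the shape arithmetic in ExtractImagePatches.infer_shape can be cross-checked with a small NumPy reference. This is a minimal sketch of the VALID branch only; the helper name extract_patches_valid and the pure-NumPy NHWC layout are illustrative assumptions, not MindSpore or TBE API.

import numpy as np

def extract_patches_valid(images, ksizes, strides, rates):
    """Reference NHWC patch extraction with VALID padding."""
    _, ksize_row, ksize_col, _ = ksizes
    _, stride_row, stride_col, _ = strides
    _, rate_row, rate_col, _ = rates
    in_batch, in_row, in_col, in_depth = images.shape
    # Effective (dilated) kernel extent, matching the formula in infer_shape.
    eff_row = ksize_row + (ksize_row - 1) * (rate_row - 1)
    eff_col = ksize_col + (ksize_col - 1) * (rate_col - 1)
    out_row = (in_row - eff_row) // stride_row + 1
    out_col = (in_col - eff_col) // stride_col + 1
    out = np.empty((in_batch, out_row, out_col, ksize_row * ksize_col * in_depth),
                   dtype=images.dtype)
    for r in range(out_row):
        for c in range(out_col):
            r0, c0 = r * stride_row, c * stride_col
            # Strided slice picks the dilated patch; flatten row-major as
            # [row, col, depth], which is the out_depth ordering documented above.
            patch = images[:, r0:r0 + eff_row:rate_row, c0:c0 + eff_col:rate_col, :]
            out[:, r, c, :] = patch.reshape(in_batch, -1)
    return out

# A 3x3 image with 2x2 kernels yields (1, 2, 2, 4) in NHWC, i.e. the
# (1, 4, 2, 2) NCHW result in the Unfold docstring after the final transpose.
img = np.arange(9, dtype=np.float16).reshape(1, 3, 3, 1)
print(extract_patches_valid(img, [1, 2, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1]).shape)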
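
A second sketch, reusing extract_patches_valid from above, illustrates the trick behind get_bprop_extract_image_patches: run the forward op over a map of pixel indices, turn the result into a 0/1 scatter matrix from input pixel to output position, and recover dx with a single matmul (the patch builds that matrix with ScatterNd and multiplies with MatMul; index 0 is reserved so out-of-image reads under SAME padding could be dropped). This is restricted to batch = depth = 1 for clarity and is a reviewer sketch, not the shipped gradient; note the shipped bprop casts its index map to float16, which represents integers exactly only up to 2048, while the sketch uses float32.

def unfold_backprop_valid(dout, x_shape, ksizes, strides, rates):
    """Gradient of extract_patches_valid for a single-image, single-channel input."""
    _, x_row, x_col, _ = x_shape
    # Pixel ids 1..H*W; id 0 stays reserved for padded (out-of-image) reads.
    idx = np.arange(1, x_row * x_col + 1, dtype=np.float32).reshape(1, x_row, x_col, 1)
    idx_patch = extract_patches_valid(idx, ksizes, strides, rates).astype(np.int64)
    n_out = idx_patch.size
    # One-hot scatter matrix: row = source pixel id, column = output position.
    sp = np.zeros((x_row * x_col + 1, n_out), dtype=dout.dtype)
    sp[idx_patch.reshape(-1), np.arange(n_out)] = 1
    dx = sp[1:] @ dout.reshape(n_out)  # overlapping patches accumulate here
    return dx.reshape(1, x_row, x_col, 1)

# 2x2 patches over a 3x3 image with an all-ones dout: pixels that appear in
# more patches receive proportionally larger gradients.
dx = unfold_backprop_valid(np.ones((1, 2, 2, 4), np.float32),
                           (1, 3, 3, 1), [1, 2, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1])
print(dx.reshape(3, 3))  # [[1. 2. 1.] [2. 4. 2.] [1. 2. 1.]]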