提交 6721541c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!25 Develop Cell unfold,and Op ExtractImagePatches.

Merge pull request !25 from zhangbuxue/unfold-develop
......@@ -45,12 +45,6 @@ class TbeAdapter {
std::vector<nlohmann::json> *input_list, kCreaterType creater_type);
private:
static void MaxPoolWithArgmaxAttrJsonPass(const AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void MaxPoolGradWithArgmaxAttrJsonPass(const AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void Conv2DAttrJsonPass(const AnfNodePtr &anf_node, const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void Conv2DBackpropFilterAttrJsonPass(const AnfNodePtr &anf_node,
......
......@@ -96,6 +96,7 @@ const char kNameConfusionMatrix[] = "ConfusionMatrix";
const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
const char kNameApplyAdam[] = "Adam";
const char kNameExtractImagePatches[] = "ExtractImagePatches";
const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu";
......@@ -214,6 +215,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
{string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
{string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
{string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
{prim::kPrimAssign->name(), ADPT_DESC(Assign)},
{prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
{prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
......
......@@ -322,18 +322,12 @@ class OpAdapter : public BaseOpAdapter {
Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) {
MS_EXCEPTION_IF_NULL(type);
TypeId me_type = type->type_id();
if (kObjectTypeTensorType == me_type) {
me_type = dyn_cast<TensorType>(type)->element()->type_id();
}
std::vector<int> shape;
auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
if (nullptr != normal_shape_ptr) {
shape = normal_shape_ptr->shape();
std::string format = "NCHW";
if (op->GetOpType() == kExtractImagePatchesOpName) {
format = "NHWC";
}
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
if (desc == nullptr) {
MS_LOG(ERROR) << "Update output descriptor failed!";
return FAILED;
......@@ -410,14 +404,15 @@ class OpAdapter : public BaseOpAdapter {
MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
return FAILED;
}
std::string format = "NCHW";
if (op->GetOpType() == kTopKOpName) {
format = "NHWC";
}
for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
auto tuple_type = dyn_cast<Tuple>(type);
MS_EXCEPTION_IF_NULL(tuple_type);
TypePtr type_elem = tuple_type->elements()[i];
std::string format = "NCHW";
if (op->GetOpType() == kTopKOpName) {
format = "NHWC";
}
auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(tuple_shp->shape()[i]), type_elem, format);
if (desc == nullptr) {
MS_LOG(ERROR) << "Create output descriptor failed!";
......@@ -476,6 +471,9 @@ class OpAdapter : public BaseOpAdapter {
if (desc == nullptr) {
continue;
}
if (op->GetOpType() == kExtractImagePatchesOpName) {
desc->SetFormat(ge::Format::FORMAT_NHWC);
}
it->second.update_input_desc(op, *desc);
}
}
......
......@@ -751,16 +751,20 @@ ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyT
OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}};
// MaxPoolGradWithArgmax
INPUT_MAP(MaxPoolGradWithArgmax) = {
{1, INPUT_DESC(x)},
{2, INPUT_DESC(grad)},
{3, INPUT_DESC(argmax)},
};
INPUT_MAP(MaxPoolGradWithArgmax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)}};
ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
// ExtractImagePatches
INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(images)}};
ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"rates", ATTR_DESC(rates, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
// Conv2D
INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
ATTR_MAP(Conv2D) = {
......
......@@ -95,6 +95,8 @@ DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
DECLARE_OP_ADAPTER(Conv2D)
DECLARE_OP_USE_ENUM(Conv2D)
DECLARE_OP_USE_OUTPUT(Conv2D)
DECLARE_OP_ADAPTER(ExtractImagePatches)
DECLARE_OP_USE_OUTPUT(ExtractImagePatches)
DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
......
......@@ -49,6 +49,7 @@ constexpr auto kBroadcastOpName = "Broadcast";
constexpr auto kReduceScatterOpName = "ReduceScatter";
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
constexpr auto kTopKOpName = "TopK";
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad";
......
......@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
from .image import ImageGradients, SSIM
......@@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'LSTM',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
'Embedding',
'AvgPool2d', 'MaxPool2d', 'Pad',
'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
'ImageGradients', 'SSIM',
]
......@@ -439,3 +439,51 @@ class Pad(Cell):
else:
x = self.pad(x, self.paddings)
return x
class Unfold(Cell):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NCHW.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_depth, out_row, out_col], the out_batch is same as the in_batch.
Examples:
>>> net = Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
>>> image = Tensor(np.ones([1, 1, 3, 3]), dtype=mstype.float16)
>>> net(image)
Tensor ([[[[1, 1] [1, 1]] [[1, 1], [1, 1]] [[1, 1] [1, 1]], [[1, 1], [1, 1]]]],
shape=(1, 4, 2, 2), dtype=mstype.float16)
"""
def __init__(self, ksizes, strides, rates, padding="valid"):
super(Unfold, self).__init__()
self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding)
self.transpose = P.Transpose()
self.format_NHWC = (0, 2, 3, 1)
self.format_NCHW = (0, 3, 1, 2)
def construct(self, input_x):
x_transpose = self.transpose(input_x, self.format_NHWC)
ret = self.extract_image_patches(x_transpose)
ret_transpose = self.transpose(ret, self.format_NCHW)
return ret_transpose
......@@ -14,7 +14,7 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
from mindspore.common import dtype as mstype
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
......@@ -52,6 +52,61 @@ def get_bprop_conv2d(self):
return bprop
@bprop_getters.register(P.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
"""Grad definition for `ExtractImagePatches` operation."""
get_shape = P.Shape()
reshape = P.Reshape()
extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes,
strides=self.strides,
rates=self.rates,
padding=self.padding)
concat = P.Concat(axis=-1)
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
dtype = P.DType()
fill = P.Fill()
slice_op = P.Slice()
transpose = P.Transpose()
matmul = P.MatMul()
cast = P.Cast()
_, ksizes_row, ksizes_col, _ = self.ksizes
def bprop(x, out, dout):
x_shape = get_shape(x)
x_batch, x_row, x_col, x_depth = x_shape
x_indices_num = x_row * x_col + 1
x_idx = F.tuple_to_array(range(1, x_indices_num))
x_idx = reshape(x_idx, (1, x_row, x_col, 1))
x_idx = cast(x_idx, mstype.float16)
x_idx_patch = extract_image_patches(x_idx)
x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2))
x_idx_patch = cast(x_idx_patch, mstype.int32)
out_shape = get_shape(out)
_, out_row, out_col, _ = out_shape
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
out_idx = F.tuple_to_array(range(out_indices_num))
out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col))
idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
idx_tensor = reshape(idx_tensor, (-1, 2))
sp_shape = (x_indices_num, out_indices_num)
sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
grad = transpose(grad, (1, 2, 3, 4, 0, 5))
grad = reshape(grad, (-1, x_batch * x_depth))
jac = matmul(sp_tensor, grad)
dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
dx = transpose(dx, (2, 0, 1, 3))
return (dx,)
return bprop
@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
"""Grad definition for `DepthwiseConv2dNative` operation."""
......
......@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm,
LogSoftmax,
MaxPool,
MaxPool, ExtractImagePatches,
AvgPool, Conv2DBackpropInput,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
......@@ -89,6 +89,7 @@ __all__ = [
'Sqrt',
'Square',
'Conv2D',
'ExtractImagePatches',
'Flatten',
'MaxPoolWithArgmax',
'FusedBatchNorm',
......
......@@ -1475,7 +1475,7 @@ class LogicalNot(PrimitiveWithInfer):
Computes the "logical NOT" of a tensor element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor whose dtype is bool
- **input_x** (Tensor) - The input tensor whose dtype is bool.
Outputs:
Tensor, the shape is same as the `input_x`, and the dtype is bool.
......
......@@ -2550,6 +2550,7 @@ class ApplyFtrl(PrimitiveWithInfer):
Outputs:
Tensor, representing the updated var.
"""
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
......@@ -2570,8 +2571,99 @@ class ApplyFtrl(PrimitiveWithInfer):
args = {'var_type': var_type, 'accum_type': accum_type, 'linear_type': linear_type, 'grad_type': grad_type}
validator.check_type_same(args, (mstype.float32, mstype.float16))
validator.check_typename("lr", lr_type,[mstype.float16, mstype.float32])
validator.check_typename("l1", l1_type,[mstype.float16, mstype.float32])
validator.check_typename("l2", l2_type,[mstype.float16, mstype.float32])
validator.check_typename("lr_power", lr_power_type,[mstype.float16, mstype.float32])
validator.check_typename("lr", lr_type, [mstype.float16, mstype.float32])
validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32])
validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32])
validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32])
return var_type
class ExtractImagePatches(PrimitiveWithInfer):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NHWC.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
"""
@prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"):
"""init"""
validator.check_type("ksizes", ksizes, [tuple, list])
validator.check_type("strides", strides, [tuple, list])
validator.check_type("rates", rates, [tuple, list])
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
self.add_prim_attr("padding", self.padding)
if len(ksizes) != 4 or ksizes[0] != 1 or ksizes[3] != 1:
raise ValueError("The format of ksizes should be [1, ksize_row, ksize_col, 1], "
f"but got {ksizes}.")
if not isinstance(ksizes[1], int) or not isinstance(ksizes[2], int) or \
ksizes[1] < 1 or ksizes[2] < 1:
raise ValueError("The ksize_row and ksize_col in ksizes should be an positive integer number, "
f"but got ksize_row is {ksizes[1]}, ksize_col is {ksizes[2]}")
if len(strides) != 4 or strides[0] != 1 or strides[3] != 1:
raise ValueError("The format of strides should be [1, stride_row, stride_col, 1], "
f"but got {strides}.")
if not isinstance(strides[1], int) or not isinstance(strides[2], int) or \
strides[1] < 1 or strides[2] < 1:
raise ValueError("The stride_row and stride_col in strides should be an positive integer number, "
f"but got stride_row is {strides[1]}, stride_col is {strides[2]}")
if len(rates) != 4 or rates[0] != 1 or rates[3] != 1:
raise ValueError("The format of rates should be [1, rate_row, rate_col, 1], "
f"but got {rates}.")
if not isinstance(rates[1], int) or not isinstance(rates[2], int) or \
rates[1] < 1 or rates[2] < 1:
raise ValueError("The rate_row and rate_col in rates should be an positive integer number, "
f"but got rate_row is {rates[1]}, rate_col is {rates[2]}")
def infer_shape(self, input_x):
in_batch, in_row, in_col, in_depth = input_x
_, ksize_row, ksize_col, _ = self.ksizes
_, stride_row, stride_col, _ = self.strides
_, rate_row, rate_col, _ = self.rates
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
out_batch = in_batch
out_depth = ksize_row * ksize_col * in_depth
if self.padding == "VALID":
out_row = \
(in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
out_col = \
(in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
else:
out_row = (in_row - 1) // stride_row + 1
out_col = (in_col - 1) // stride_col + 1
out_shape = [out_batch, out_row, out_col, out_depth]
return out_shape
def infer_dtype(self, input_x):
validator.check_subclass("input_x", input_x, mstype.tensor)
validator.check_typename("input_x_dtype", input_x, (mstype.int8, mstype.float16, mstype.float32))
return input_x
......@@ -30,6 +30,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
# pylint: disable=W0613
# pylint: disable=W0231
# W0613: unused-argument
......@@ -106,7 +108,7 @@ def test_realdiv():
result = div(x, y)
x = x.asnumpy()
y = y.asnumpy()
expect = x/y
expect = x / y
assert np.all(result.asnumpy() == expect)
......@@ -122,6 +124,7 @@ def test_eye():
class VirtualLossGrad(PrimitiveWithInfer):
""" VirtualLossGrad definition """
@prim_attr_register
def __init__(self):
"""init VirtualLossGrad"""
......@@ -138,6 +141,7 @@ class VirtualLossGrad(PrimitiveWithInfer):
class VirtualLoss(PrimitiveWithInfer):
""" VirtualLoss definition """
@prim_attr_register
def __init__(self):
"""init VirtualLoss"""
......@@ -151,6 +155,7 @@ class VirtualLoss(PrimitiveWithInfer):
def bprop(x, out, dout):
dx = loss_grad(x, out, dout)
return (dx,)
return bprop
def infer_shape(self, x_shape):
......@@ -162,6 +167,7 @@ class VirtualLoss(PrimitiveWithInfer):
class NetWithLoss(nn.Cell):
""" NetWithLoss definition """
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
......@@ -174,6 +180,7 @@ class NetWithLoss(nn.Cell):
class GradWrap(nn.Cell):
""" GradWrap definition """
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
......@@ -184,6 +191,7 @@ class GradWrap(nn.Cell):
class MatMulNet(nn.Cell):
""" MatMulNet definition """
def __init__(self):
super(MatMulNet, self).__init__()
self.matmul = P.MatMul()
......@@ -195,6 +203,7 @@ class MatMulNet(nn.Cell):
class NetWithLossSub(nn.Cell):
""" NetWithLossSub definition """
def __init__(self, network):
super(NetWithLossSub, self).__init__()
self.loss = VirtualLoss()
......@@ -207,6 +216,7 @@ class NetWithLossSub(nn.Cell):
class GradWrapSub(nn.Cell):
""" GradWrapSub definition """
def __init__(self, network):
super(GradWrapSub, self).__init__()
self.network = network
......@@ -217,6 +227,7 @@ class GradWrapSub(nn.Cell):
class SubNet(nn.Cell):
""" SubNet definition """
def __init__(self):
super(SubNet, self).__init__()
self.sub = P.Sub()
......@@ -227,6 +238,7 @@ class SubNet(nn.Cell):
class NpuFloatNet(nn.Cell):
""" NpuFloat definition """
def __init__(self):
super(NpuFloatNet, self).__init__()
self.mul = P.Mul()
......@@ -258,6 +270,7 @@ class NpuFloatNet(nn.Cell):
class DiagNet(nn.Cell):
""" DiagNet definition """
def __init__(self):
super(DiagNet, self).__init__()
self.fill = P.Fill()
......@@ -269,6 +282,7 @@ class DiagNet(nn.Cell):
class NetWithLossCumSum(nn.Cell):
""" NetWithLossCumSum definition """
def __init__(self, network):
super(NetWithLossCumSum, self).__init__()
self.loss = VirtualLoss()
......@@ -281,6 +295,7 @@ class NetWithLossCumSum(nn.Cell):
class GradWrapCumSum(nn.Cell):
""" GradWrap definition """
def __init__(self, network):
super(GradWrapCumSum, self).__init__()
self.network = network
......@@ -291,6 +306,7 @@ class GradWrapCumSum(nn.Cell):
class NetCumSum(nn.Cell):
""" NetCumSum definition """
def __init__(self):
super(NetCumSum, self).__init__()
self.cumsum = P.CumSum()
......@@ -321,8 +337,8 @@ test_case_math_ops = [
'skip': ['backward']}),
('CumSumGrad', {
'block': GradWrapCumSum(NetWithLossCumSum(NetCumSum())),
'desc_inputs': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))],
'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
'skip': ['backward']}),
('Diag', {
'block': DiagNet(),
......@@ -351,7 +367,6 @@ test_case_math_ops = [
'skip': ['backward']}),
]
test_case_lists = [test_case_math_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast
......@@ -360,6 +375,7 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
import mindspore.context as context
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
......@@ -369,16 +385,16 @@ def test_exec():
raise_set = [
('StridedSlice_1_Error', {
'block': (lambda x : P.StridedSlice(begin_mask="1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': ValueError}),
'desc_inputs': [0]}),
('StridedSlice_2_Error', {
'block': (lambda x : P.StridedSlice(end_mask="1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': ValueError}),
'desc_inputs': [0]}),
('StridedSlice_3_Error', {
'block': (lambda x : P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}),
'desc_inputs': [0]}),
('StridedSlice_4_Error', {
'block': (lambda x : P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}),
'desc_inputs': [0]}),
]
......
......@@ -382,6 +382,46 @@ def test_max_pool_with_arg_max():
print(ret)
class GradWrapUnfold(nn.Cell):
""" GradWrapUnfold definition """
def __init__(self, network):
super(GradWrapUnfold, self).__init__()
self.network = network
self.sens = Tensor(np.ones([1, 4, 2, 2], np.float32))
def construct(self, x):
return C.grad_all_with_sens(self.network)(x, self.sens)
class UnfoldNetValid(nn.Cell):
""" UnfoldNetValid definition """
def __init__(self):
super(UnfoldNetValid, self).__init__()
self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
strides=[1, 1, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')
def construct(self, x):
return self.unfold(x)
class UnfoldNetSame(nn.Cell):
""" UnfoldNetSame definition """
def __init__(self):
super(UnfoldNetSame, self).__init__()
self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
strides=[1, 1, 1, 1],
rates=[1, 1, 1, 1],
padding='SAME')
def construct(self, x):
return self.unfold(x)
test_cases = [
('SoftMaxGrad', {
'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
......@@ -440,6 +480,21 @@ test_cases = [
'block': ComparisonNet(),
'desc_inputs': [Tensor(np.ones([6, 9, 10], np.int32)), Tensor(np.ones([6, 9, 10], np.int32))],
}),
('UnfoldValid', {
'block': UnfoldNetValid(),
'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
'skip': ['backward']}),
('UnfoldSame', {
'block': UnfoldNetSame(),
'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
'desc_bprop': [Tensor(np.ones([1, 4, 3, 3], np.float32))],
'skip': ['backward']}),
('UnfoldGrad', {
'block': GradWrapUnfold(UnfoldNetValid()),
'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
'skip': ['backward']}),
]
test_cases_for_verify_exception = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册