提交 0f89cc1d 编写于 作者: B buxue

dock AcoshGrad for GE and AvgPool AvgPoolGrad for Vm

上级 a62c3e5c
......@@ -38,6 +38,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"reduce_mean", "reduce_mean_d"},
{"reduce_max", "reduce_max_d"},
{"reduce_min", "reduce_min_d"},
{"avg_pool_grad", "avg_pool_grad_d"},
{"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
{"conv2d_backprop_input", "conv2d_backprop_input_d"},
{"depthwise_conv2d_native", "depthwise_conv2d"},
......
......@@ -170,6 +170,7 @@ const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
......
......@@ -178,6 +178,7 @@ extern const PrimitivePtr kPrimFusedBatchNorm;
extern const PrimitivePtr kPrimConv2D;
extern const PrimitivePtr kPrimMaxPool;
extern const PrimitivePtr kPrimMaxPoolGrad;
extern const PrimitivePtr kPrimAvgPoolGrad;
extern const PrimitivePtr kPrimFusedBatchNormGrad;
extern const PrimitivePtr kPrimReluGrad;
extern const PrimitivePtr kPrimConv2DBackpropInput;
......
......@@ -25,6 +25,7 @@ namespace mindspore {
namespace opt {
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimCast->name(), {1});
Register(prim::kPrimAvgPoolGrad->name(), {0});
Register(prim::kPrimConv2DBackpropInput->name(), {2});
Register(prim::kPrimConv2DBackpropFilter->name(), {2});
Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1});
......
......@@ -178,6 +178,7 @@ const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
const char kNameAcosh[] = "Acosh";
const char kNameAcoshGrad[] = "AcoshGrad";
const char kNameFloorMod[] = "FloorMod";
const char kNameSpaceToDepth[] = "SpaceToDepth";
const char kNameDepthToSpace[] = "DepthToSpace";
......@@ -375,6 +376,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},
{string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)},
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
......
......@@ -357,6 +357,11 @@ INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Acosh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}};
// AcoshGrad
INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
// Floor
INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
......
......@@ -327,13 +327,15 @@ DECLARE_OP_ADAPTER(Const)
DECLARE_OP_USE_OUTPUT(Const)
DECLARE_OP_ADAPTER(Cos)
DECLARE_OP_USE_OUTPUT(Cos)
DECLARE_OP_ADAPTER(Acos)
DECLARE_OP_USE_OUTPUT(Acos)
DECLARE_OP_ADAPTER(AcosGrad)
DECLARE_OP_USE_OUTPUT(AcosGrad)
DECLARE_OP_ADAPTER(Acosh)
DECLARE_OP_USE_OUTPUT(Acosh)
DECLARE_OP_ADAPTER(AcoshGrad)
DECLARE_OP_USE_OUTPUT(AcoshGrad)
DECLARE_OP_ADAPTER(Floor)
DECLARE_OP_USE_OUTPUT(Floor)
......
......@@ -21,6 +21,7 @@ from mindspore._checkparam import check_int_positive, check_bool
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import identity
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore.common.api import ms_function
......@@ -480,7 +481,7 @@ class Unfold(Cell):
"""
def __init__(self, ksizes, strides, rates, padding="valid"):
super(Unfold, self).__init__()
self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding)
self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
self.transpose = P.Transpose()
self.format_NHWC = (0, 2, 3, 1)
self.format_NCHW = (0, 3, 1, 2)
......
......@@ -18,6 +18,7 @@ from mindspore.common import dtype as mstype
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
......@@ -29,6 +30,7 @@ def get_bprop_bias_add(self):
def bprop(x, w, out, dout):
return dout, bias_grad(dout)
return bprop
......@@ -49,18 +51,19 @@ def get_bprop_conv2d(self):
dx = input_grad(dout, w, get_shape(x))
dw = filter_grad(dout, x, get_shape(w))
return dx, dw
return bprop
@bprop_getters.register(P.ExtractImagePatches)
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
"""Grad definition for `ExtractImagePatches` operation."""
get_shape = P.Shape()
reshape = P.Reshape()
extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes,
strides=self.strides,
rates=self.rates,
padding=self.padding)
extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
strides=self.strides,
rates=self.rates,
padding=self.padding)
concat = P.Concat(axis=-1)
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
......@@ -104,6 +107,7 @@ def get_bprop_extract_image_patches(self):
dx = transpose(dx, (2, 0, 1, 3))
return (dx,)
return bprop
......@@ -124,6 +128,7 @@ def get_bprop_depthwise_conv2d_native(self):
dx = input_grad(get_shape(x), w, dout)
dw = filter_grad(x, get_shape(w), dout)
return dx, dw
return bprop
......@@ -133,11 +138,12 @@ def get_bprop_max_pool_with_argmax(self):
maxpool_grad = G.MaxPoolGradWithArgmax(
ksize=self.ksize,
strides=self.strides,
padding=self.padding,)
padding=self.padding)
def bprop(x, out, dout):
dx = maxpool_grad(x, dout[0], out[1])
return (dx,)
return bprop
......@@ -152,6 +158,7 @@ def get_bprop_max_pool_grad(self):
def bprop(x, out, dout):
dx = maxpool_grad(x, out, dout)
return (dx,)
return bprop
......@@ -192,6 +199,7 @@ def get_bprop_dropout_gen_mask(self):
def bprop(shape, keep_prob, out, dout):
return (zeros_like(shape), zeros_like(keep_prob))
return bprop
......@@ -202,6 +210,7 @@ def get_bprop_dropout_do_mask(self):
def bprop(x, y, keep_prob, out, dout):
return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
return bprop
......@@ -213,6 +222,7 @@ def get_bprop_relu(self):
def bprop(x, out, dout):
dx = input_grad(dout, out)
return (dx,)
return bprop
......@@ -224,6 +234,7 @@ def get_bprop_relu6(self):
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
......@@ -236,6 +247,7 @@ def get_bprop_relu_v2(self):
mask = out[1]
dx = input_grad(dout[0], mask)
return (dx,)
return bprop
......@@ -247,6 +259,7 @@ def get_bprop_hswish(self):
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
......@@ -258,6 +271,7 @@ def get_bprop_hsigmoid(self):
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
......@@ -269,6 +283,7 @@ def get_bprop_elu(self):
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
......@@ -280,6 +295,7 @@ def get_bprop_sigmoid(self):
def bprop(x, out, dout):
dx = input_grad(out, dout)
return (dx,)
return bprop
......@@ -294,6 +310,7 @@ def get_bprop_softmax(self):
def bprop(x, out, dout):
dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out)
return (dx,)
return bprop
......@@ -305,6 +322,7 @@ def get_bprop_log_softmax(self):
def bprop(x, out, dout):
dx = logsoftmax_grad(out, dout)
return (dx,)
return bprop
......@@ -316,6 +334,7 @@ def get_bprop_tanh(self):
def bprop(x, out, dout):
dx = logsoftmax_grad(out, dout)
return (dx,)
return bprop
......@@ -327,6 +346,7 @@ def get_bprop_gelu(self):
def bprop(x, out, dout):
dx = input_grad(dout, x, out)
return (dx,)
return bprop
......@@ -343,6 +363,7 @@ def get_bprop_fused_batch_norm(self):
dscale = out[1]
dbias = out[2]
return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
return bprop
......@@ -366,6 +387,7 @@ def get_bprop_batch_norm(self):
dscale = out[1]
dbias = out[2]
return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
return bprop
......@@ -377,6 +399,7 @@ def get_bprop_layer_norm(self):
def bprop(x, gamma, beta, out, dout):
dx, d_gamma, d_beta = layer_norm_grad(x, dout[0], out[2], out[1], gamma)
return dx, d_gamma, d_beta
return bprop
......@@ -388,6 +411,7 @@ def get_bprop_l2normalize(self):
def bprop(x, out, dout):
dx = input_grad(x, out, dout)
return (dx,)
return bprop
......@@ -400,6 +424,7 @@ def get_bprop_softmax_cross_entropy_with_logits(self):
grad = out[1]
grad = grad * expand(dout[0], -1)
return grad, zeros_like(labels)
return bprop
......@@ -417,6 +442,7 @@ def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
grad = F.depend(grad, out)
grad = grad * dout
return grad, zeros_like(labels)
return bprop
......@@ -428,6 +454,7 @@ def get_bprop_resize_bilinear(self):
def bprop(x, out, dout):
dx = resize_grad(dout, x)
return (dx,)
return bprop
......@@ -437,6 +464,7 @@ def get_bprop_onehot(self):
def bprop(indices, depth, on_value, off_value, out, dout):
return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
return bprop
......@@ -453,6 +481,7 @@ def get_bprop_top_kv2(self):
updates = dout[0]
shapes = shape_op(input_x)
return scatter(indices, updates, shapes), zeros_like(k)
return bprop
......@@ -518,6 +547,7 @@ def get_bprop_lstm(self):
dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
return dx, dhx, dcx, dw
return bprop
......@@ -529,6 +559,7 @@ def get_bprop_sigmoid_crossentropy_with_logits(self):
def bprop(x, y, out, dout):
dx = op(x, y, dout)
return (dx, zeros_like(y))
return bprop
......@@ -545,6 +576,7 @@ def get_bprop_pad(self):
shp = shape_op(x)
dx = P.Slice()(dout, begin, shp)
return (dx,)
return bprop
......@@ -556,6 +588,7 @@ def get_bprop_mirror_pad(self):
def bprop(x, paddings, out, dout):
dx = mirror_pad_grad(dout, paddings, x)
return (dx, zeros_like(paddings))
return bprop
......
......@@ -151,3 +151,5 @@ from .greater_equal import _greater_equal_tbe
from .not_equal import _not_equal_tbe
from .floor_mod import _floor_mod_tbe
from .scatter_nd_update import _scatter_nd_update_tbe
from .avg_pool import _avg_pool_tbe
from .avg_pool_grad import _avg_pool_grad_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AvgPool op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
avg_pool_op_info = TBERegOp("AvgPool") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("avg_pool.so") \
.compute_cost(10) \
.kernel_name("avg_pool") \
.partial_flag(True) \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.get_op_info()
@op_info_register(avg_pool_op_info)
def _avg_pool_tbe():
"""AvgPool TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AvgPoolGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
avg_pool_grad_op_info = TBERegOp("AvgPoolGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("avg_pool_grad_d.so") \
.compute_cost(10) \
.kernel_name("avg_pool_grad_d") \
.partial_flag(True) \
.attr("x_origin", "required", "listInt", "all") \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "input_grad", False, "required", "all") \
.input(1, "mean_matrix", False, "optional", "all") \
.input(2, "kernel_matrix", False, "optional", "all") \
.output(0, "out_grad", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \
.get_op_info()
@op_info_register(avg_pool_grad_op_info)
def _avg_pool_grad_tbe():
"""AvgPoolGrad TBE register"""
return
......@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss,
LogSoftmax,
MaxPool, ExtractImagePatches,
MaxPool,
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
......@@ -89,7 +89,6 @@ __all__ = [
'Sqrt',
'Square',
'Conv2D',
'ExtractImagePatches',
'Flatten',
'MaxPoolWithArgmax',
'FusedBatchNorm',
......
......@@ -59,6 +59,23 @@ class ACosGrad(PrimitiveWithInfer):
return x
class AcoshGrad(PrimitiveWithInfer):
"""Performs grad of Acosh operation."""
@prim_attr_register
def __init__(self):
"""init AcoshGrad"""
def infer_shape(self, x, dout):
validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
return x
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x
class BatchNormGrad(PrimitiveWithInfer):
"""Performs grad of BatchNorm operation."""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inner operators."""
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
class ExtractImagePatches(PrimitiveWithInfer):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NHWC.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
"""
@prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"):
"""init"""
def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
f"{arg_name}_col, 1], but got {arg_val}.")
if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
f"is {arg_val[2]}")
_check_tuple_or_list("ksize", ksizes, self.name)
_check_tuple_or_list("stride", strides, self.name)
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.add_prim_attr("padding", self.padding)
def infer_shape(self, input_x):
"""infer shape"""
in_batch, in_row, in_col, in_depth = input_x
_, ksize_row, ksize_col, _ = self.ksizes
_, stride_row, stride_col, _ = self.strides
_, rate_row, rate_col, _ = self.rates
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
out_batch = in_batch
out_depth = ksize_row * ksize_col * in_depth
if self.padding == "VALID":
out_row = \
(in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
out_col = \
(in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
else:
out_row = (in_row - 1) // stride_row + 1
out_col = (in_col - 1) // stride_col + 1
out_shape = [out_batch, out_row, out_col, out_depth]
return out_shape
def infer_dtype(self, input_x):
"""infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
return input_x
......@@ -2654,82 +2654,6 @@ class ApplyFtrl(PrimitiveWithInfer):
return var_type
class ExtractImagePatches(PrimitiveWithInfer):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NHWC.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
"""
@prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"):
"""init"""
def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
f"{arg_name}_col, 1], but got {arg_val}.")
if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
f"is {arg_val[2]}")
_check_tuple_or_list("ksize", ksizes, self.name)
_check_tuple_or_list("stride", strides, self.name)
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.add_prim_attr("padding", self.padding)
def infer_shape(self, input_x):
in_batch, in_row, in_col, in_depth = input_x
_, ksize_row, ksize_col, _ = self.ksizes
_, stride_row, stride_col, _ = self.strides
_, rate_row, rate_col, _ = self.rates
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
out_batch = in_batch
out_depth = ksize_row * ksize_col * in_depth
if self.padding == "VALID":
out_row = \
(in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
out_col = \
(in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
else:
out_row = (in_row - 1) // stride_row + 1
out_col = (in_col - 1) // stride_col + 1
out_shape = [out_batch, out_row, out_col, out_depth]
return out_shape
def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
return input_x
class ConfusionMulGrad(PrimitiveWithInfer):
"""
`output0` is the result of which input0 dot multily input1.
......
......@@ -265,8 +265,8 @@ test_case_math_ops = [
'desc_bprop': [[2, 3]]}),
('Acosh', {
'block': P.Acosh(),
'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16))],
'skip': ['backward']}),
'desc_inputs': [[3, 4, 5]],
'desc_bprop': [[3, 4, 5]]}),
('Sin', {
'block': P.Sin(),
'desc_inputs': [[2, 3]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册