From 3f38b1a00e2ef6149b29347bace2c6ab4a09ed9b Mon Sep 17 00:00:00 2001 From: z00478463 Date: Wed, 27 May 2020 10:48:45 +0800 Subject: [PATCH] for comments for debug for DEBUG for DEBUG for DEBUG for DEBUG for well performance for pylint for te chip for pylint for pylint nth --- example/resnet50_imagenet2012_THOR/config.py | 2 +- example/resnet50_imagenet2012_THOR/eval.py | 60 +++++ .../resnet50_imagenet2012_THOR/model/thor.py | 1 + .../resnet50_imagenet2012_THOR/run_infer.sh | 64 +++++ example/resnet50_imagenet2012_THOR/train.py | 2 +- .../matmul_cube_fracz_left_cast_impl.py | 76 +++--- mindspore/ops/operations/thor_ops.py | 225 +++++++++++++++++- 7 files changed, 378 insertions(+), 52 deletions(-) create mode 100755 example/resnet50_imagenet2012_THOR/eval.py create mode 100755 example/resnet50_imagenet2012_THOR/run_infer.sh diff --git a/example/resnet50_imagenet2012_THOR/config.py b/example/resnet50_imagenet2012_THOR/config.py index fc01287cc..cd0a81d5e 100644 --- a/example/resnet50_imagenet2012_THOR/config.py +++ b/example/resnet50_imagenet2012_THOR/config.py @@ -23,7 +23,7 @@ config = ed({ "loss_scale": 128, "momentum": 0.9, "weight_decay": 5e-4, - "epoch_size": 50, + "epoch_size": 45, "buffer_size": 1000, "image_height": 224, "image_width": 224, diff --git a/example/resnet50_imagenet2012_THOR/eval.py b/example/resnet50_imagenet2012_THOR/eval.py new file mode 100755 index 000000000..db82b9fca --- /dev/null +++ b/example/resnet50_imagenet2012_THOR/eval.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +eval. +""" +import os +import argparse +from dataset_imagenet import create_dataset +from config import config +from mindspore import context +from mindspore.model_zoo.resnet import resnet50 +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from crossentropy import CrossEntropy + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.') +parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) +context.set_context(device_id=device_id) + +if __name__ == '__main__': + + net = resnet50(class_num=config.class_num) + if not config.label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + + if args_opt.do_eval: + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + + if args_opt.checkpoint_path: + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + model = Model(net, loss_fn=loss, metrics={'acc'}) + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/example/resnet50_imagenet2012_THOR/model/thor.py b/example/resnet50_imagenet2012_THOR/model/thor.py index d414f2385..0da1714fe 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor.py +++ b/example/resnet50_imagenet2012_THOR/model/thor.py @@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor from mindspore.nn.optim.optimizer import Optimizer from mindspore.ops import functional as F, composite as C, operations as P from mindspore.parallel._utils import _get_device_num, _get_mirror_mean +from model.grad_reducer_thor import DistributedGradReducerThor momentum_opt = C.MultitypeFuncGraph("momentum_opt") diff --git a/example/resnet50_imagenet2012_THOR/run_infer.sh b/example/resnet50_imagenet2012_THOR/run_infer.sh new file mode 100755 index 000000000..14d7faf98 --- /dev/null +++ b/example/resnet50_imagenet2012_THOR/run_infer.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "infer" ]; +then + rm -rf ./infer +fi +mkdir ./infer +cp *.py ./infer +cp *.sh ./infer +cd ./infer || exit +env > env.log +echo "start infering for device $DEVICE_ID" +python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & +cd .. diff --git a/example/resnet50_imagenet2012_THOR/train.py b/example/resnet50_imagenet2012_THOR/train.py index b98d13b8a..15710bc66 100644 --- a/example/resnet50_imagenet2012_THOR/train.py +++ b/example/resnet50_imagenet2012_THOR/train.py @@ -109,7 +109,7 @@ if __name__ == '__main__': step_size = dataset.get_dataset_size() loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004)) + lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004)) opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py index 9a30da378..11b668445 100644 --- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py @@ -486,41 +486,41 @@ def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], input_x2_ub[count * repeate_times_max * vectorfp32_size], repeate_num, 1, 1, 4, 8) - input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], - name="input_x2_L1", scope=tik.scope_cbuf) - tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1, - no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0) - # input_x1 -> input_x1_L1 - input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], - name="input_x1_L1", scope=tik.scope_cbuf) - tik_instance.data_move(input_x1_L1, - input_x1[k_idx, - core_m * mo_tile, 0, 0], - 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, - (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) - # input_x2_L1 -> input_x2_L0B - input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], - name="input_x2_L0B", scope=tik.scope_cb) - with tik_instance.for_range(0, ko_tile_inner) as cc2: - tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, - ko_tile_inner, - 0, True) - # input_x1_L1 -> input_x1_L0A - input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], - name="input_x1_L0A", scope=tik.scope_ca) - with tik_instance.for_range(0, mo_tile) as cc1: - tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, - mo_tile, 0, False) - with tik_instance.if_scope(thread_idx_k == 0): - tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, - ko_tile_inner * c0, no_tile * c0, 0) - with tik_instance.else_scope(): - tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, - ko_tile_inner * c0, no_tile * c0, 1) - res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0], - name="resMatmul_ub", scope=tik.scope_ubuf) - tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1) - tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0], - res_ub, 0, no_tile, - mo_tile * c0 * c0 * fp16_size // blocksize, 0, - (mo - mo_tile) * c0 * c0 * fp16_size // blocksize) + input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1, + no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0) + # input_x1 -> input_x1_L1 + input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], + name="input_x1_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x1_L1, + input_x1[k_idx, + core_m * mo_tile, 0, 0], + 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) + # input_x2_L1 -> input_x2_L0B + input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], + name="input_x2_L0B", scope=tik.scope_cb) + with tik_instance.for_range(0, ko_tile_inner) as cc2: + tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, + ko_tile_inner, + 0, True) + # input_x1_L1 -> input_x1_L0A + input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], + name="input_x1_L0A", scope=tik.scope_ca) + with tik_instance.for_range(0, mo_tile) as cc1: + tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, + mo_tile, 0, False) + with tik_instance.if_scope(thread_idx_k == 0): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 0) + with tik_instance.else_scope(): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 1) + res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0], + name="resMatmul_ub", scope=tik.scope_ubuf) + tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1) + tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0], + res_ub, 0, no_tile, + mo_tile * c0 * c0 * fp16_size // blocksize, 0, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize) diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py index 5e6ff4b95..54d097b0c 100644 --- a/mindspore/ops/operations/thor_ops.py +++ b/mindspore/ops/operations/thor_ops.py @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================ """thor_ops""" -import mindspore as ms from mindspore.ops import prim_attr_register, PrimitiveWithInfer from mindspore.ops.composite import multitype_ops as C +import mindspore as ms + __all__ = ["CusBatchMatMul", "CusCholeskyTrsm", "CusFusedAbsMax1", @@ -33,12 +34,31 @@ __all__ = ["CusBatchMatMul", class CusBatchMatMul(PrimitiveWithInfer): """CusBatchMatMul definition""" + """ + Multiplies matrix `a` by matrix `b` in batch. + + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If + `transpose_b` is True. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, D, D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cus_batch_matmul = P.CusBatchMatMul() + >>> output = cus_batch_matmul(input_x, input_y) + """ @prim_attr_register def __init__(self): """init CusBatchMatMul""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -54,12 +74,30 @@ class CusBatchMatMul(PrimitiveWithInfer): class CusCholeskyTrsm(PrimitiveWithInfer): """CusCholeskyTrsm definition""" + """ + L * LT = A. + LT * (LT)^-1 = I. + return (LT)^-1. + Only compute the res of the diag part of input matrix with dim 128. + The rank of input tensors must be `2`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32) + >>> cus_choleskytrsm = P.CusCholeskyTrsm() + >>> output = matmul(input_x) + """ @prim_attr_register def __init__(self): """init CusCholeskyTrsm""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm def infer_shape(self, data1_shape): ll = [] m, _ = data1_shape @@ -75,13 +113,28 @@ class CusCholeskyTrsm(PrimitiveWithInfer): class CusFusedAbsMax1(PrimitiveWithInfer): """CusFusedAbsMax1 definition""" + """ + Compute the abs max of Tensor input. + + The rank of input tensors must be `4` or `2`. + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N0, M0, N1, M1)` + or math:`(32, 64)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(32, 64)` or math:`(1, )`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32) + >>> cus_fused_abs_max1 = P.CusFusedAbsMax1() + >>> output = cus_fused_abs_max1(input_x) + """ @prim_attr_register def __init__(self, origin_shape=[-1, -1]): """init CusFusedAbsMax1""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.origin_shape = origin_shape - + from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1 def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -102,6 +155,21 @@ class CusFusedAbsMax1(PrimitiveWithInfer): class CusImg2Col(PrimitiveWithInfer): """CusImg2Col definition""" + """ + Img2col the feature map and the result in reorganized in NC1HWC0. + + Args: + - **strides** (listInt) - the stride of the ops. + - **ksizes** (listInt) - the kernel size of the ops. + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16) + >>> cusimg2col = P.CusImg2Col() + >>> output = cusimg2col(input_x) + """ @prim_attr_register def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"): @@ -111,7 +179,7 @@ class CusImg2Col(PrimitiveWithInfer): self.strides = strides self.dilates = dilates self.mode = mode - + from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -136,12 +204,30 @@ class CusImg2Col(PrimitiveWithInfer): class CusMatMulCubeDenseLeft(PrimitiveWithInfer): """CusMatMulCube definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 must be `4`, the fractal format of the normal matrix. + The rank of input_x2 must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(N0, M0, N1, M1)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, C)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft() + >>> output = matmulcubedenseleft(input_x, input_y) + """ @prim_attr_register def __init__(self): """init CusMatMulCubeDenseLeft""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -157,12 +243,32 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer): class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): """CusMatMulCubeFraczRightMul definition""" + """ + Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`. + + The rank of input_x1 tensors must be `2`. + The rank of input_x2 tensors must be `4`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, C0, M0)`. + - **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor if :math`(1, )`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + Examples: + >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16) + >>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul() + >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3) + """ @prim_attr_register def __init__(self): """init CusMatMulCubeFraczRightMul""" self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul def get_bprop(self): def bprop(x1, x2, x3, out, dout): return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) @@ -178,6 +284,30 @@ class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): class CusMatMulCube(PrimitiveWithInfer): """CusMatMulCube definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input tensors must be `2`. + + Args: + transpose_a (bool): If True, `a` is transposed before multiplication. Default: False. + transpose_b (bool): If True, `b` is transposed before multiplication. Default: False. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If + `transpose_a` is True, its shape should be :math:`(N, C)` after transposing. + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If + `transpose_b` is True, its shape should be :math:`(C, M)` after transpose. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcube = P.CusMatMulCube() + >>> output = matmul(input_x, input_y) + """ @prim_attr_register def __init__(self, transpose_a=False, transpose_b=False): @@ -185,7 +315,7 @@ class CusMatMulCube(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) self.transpose_a = transpose_a self.transpose_b = transpose_b - + from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -213,12 +343,27 @@ class CusMatMulCube(PrimitiveWithInfer): class CusMatrixCombine(PrimitiveWithInfer): """CusMatrixCombine definition""" + """ + move the batch matrix to result matrix diag part. + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N * D, N * D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cusmatrixcombine = P.CusMatrixCombine() + >>> output = cusmatrixcombine(input_x) + """ @prim_attr_register def __init__(self): """init CusMatrixCombine""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -237,12 +382,28 @@ class CusMatrixCombine(PrimitiveWithInfer): class CusTranspose02314(PrimitiveWithInfer): """CusTranspose02314 definition""" + """ + Permute input tensor with perm (0, 2, 3, 1, 4) + + The rank of input tensors must be `5` with format NC1HWC0. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16) + >>> custranspose02314 = P.CusTranspose02314() + >>> output = custranspose02314(input_x) + """ @prim_attr_register def __init__(self): """init CusTranspose02314""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -263,12 +424,32 @@ class CusTranspose02314(PrimitiveWithInfer): class CusMatMulCubeDenseRight(PrimitiveWithInfer): """CusMatMulCubeDenseRight definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `2`. + The rank of input_x2 tensor must be `4`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_y** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, M0, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight() + >>> output = cusmatmulcubedenseright(input_x, input_y) + """ @prim_attr_register def __init__(self): """init CusMatMulCubeDenseRight""" self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight def get_bprop(self): def bprop(x1, x2, x3, out, dout): return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) @@ -284,12 +465,32 @@ class CusMatMulCubeDenseRight(PrimitiveWithInfer): class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): """CusMatMulCubeFraczLeftCast definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `4`. + The rank of input_x2 tensors must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(C1, N1, N0, C0)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast() + >>> output = cusmatmulcubefraczleftcast(input_x, input_y) + """ @prim_attr_register def __init__(self): """init CusMatMulCubeFraczLeftCast""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) -- GitLab