array_ops_vm_impl.py 6.5 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for array ops"""
import numpy as np

import mindspore.common.dtype as mstype
J
jinyaohui 已提交
19 20
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
21
from mindspore.ops.operations import _grad_ops as G
Z
zhunaipan 已提交
22 23
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
J
jinyaohui 已提交
24 25


Z
zhunaipan 已提交
26 27 28 29 30 31
# pylint: disable=unused-argument


@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Generate vm_impl function for ExpandDims.

    The returned callable converts its input to numpy, hands it and ``axis``
    to ``vm.expand_dims``, and wraps the result in a new Tensor.
    """

    def vm_impl(x, axis):
        # A bare Python float is promoted to a one-element Tensor first so the
        # asnumpy() conversion below applies uniformly.
        source = Tensor(np.array([x])) if isinstance(x, float) else x
        expanded = vm.expand_dims(source.asnumpy(), axis)
        return Tensor(expanded)

    return vm_impl


@vm_impl_getters.register(P.DType)
def vm_impl_dType(self):
    """Generate vm_impl function for DType.

    The returned callable reports the data type of its argument by calling
    the argument's own ``dtype()`` accessor.
    """

    def vm_impl(x):
        dtype_of_x = x.dtype()
        return dtype_of_x

    return vm_impl


@vm_impl_getters.register(P.Cast)
def vm_impl_cast(self):
    """Generate vm_impl function for Cast.

    The returned callable converts tensor ``x`` to the numpy type that
    ``mstype.dtype_to_nptype`` associates with the MindSpore dtype ``t``.
    """

    def vm_impl(x, t):
        # When `t` is a tensor type (same Python class as mstype.tensor),
        # unwrap it to its element type before mapping to numpy.
        target = t.element_type() if isinstance(t, type(mstype.tensor)) else t
        values = x.asnumpy()
        return Tensor(values.astype(mstype.dtype_to_nptype(target)))

    return vm_impl


@vm_impl_getters.register(P.Reshape)
def vm_impl_reshape(self):
    """Generate vm_impl function for Reshape.

    The returned callable reshapes the numpy view of ``x`` to ``shp`` via
    ``vm.reshape`` and returns the result as a Tensor.
    """

    def vm_impl(x, shp):
        return Tensor(vm.reshape(x.asnumpy(), shp))

    return vm_impl


@vm_impl_getters.register(P.Shape)
def vm_impl_shape(self):
    """Generate vm_impl function for Shape.

    The returned callable returns whatever ``vm.shape`` reports for the
    numpy view of ``x``; the result is not wrapped in a Tensor.
    """

    def vm_impl(x):
        return vm.shape(x.asnumpy())

    return vm_impl


@vm_impl_getters.register(P.Squeeze)
def vm_impl_squeeze(self):
    """Generate vm_impl function for Squeeze.

    The returned callable squeezes ``x`` along the primitive's ``axis``
    attribute using ``vm.squeeze`` and returns the result as a Tensor.
    """

    def vm_impl(x):
        squeezed = vm.squeeze(x.asnumpy(), self.axis)
        return Tensor(squeezed)

    return vm_impl


@vm_impl_getters.register(P.Transpose)
def vm_impl_transpose(self):
    """Generate vm_impl function for Transpose.

    The returned callable permutes the axes of ``x`` with ``vm.transpose``.
    When ``perm`` is omitted, the axis order is fully reversed.
    """

    def vm_impl(x, perm=None):
        arr = x.asnumpy()
        if perm is None:
            # Default: reverse all axes, e.g. [2, 1, 0] for a rank-3 input.
            perm = list(reversed(range(len(arr.shape))))
        return Tensor(vm.transpose(arr, perm))

    return vm_impl


@vm_impl_getters.register(P.Split)
def vm_impl_split(self):
    """Generate vm_impl function for Split.

    The returned callable splits ``x`` into ``self.output_num`` equal parts
    along ``self.axis`` and returns them as a tuple of Tensors.

    Raises:
        ValueError: (from ``np.split``) if the length of ``x`` along ``axis``
            is not divisible by ``output_num``.
    """

    def vm_impl(x):
        x = x.asnumpy()
        # P.Split is configured through its `axis` and `output_num`
        # attributes, so split evenly into `output_num` pieces along `axis`
        # and return every piece rather than a fixed pair.
        pieces = np.split(x, self.output_num, axis=self.axis)
        return tuple(Tensor(piece) for piece in pieces)

    return vm_impl


@vm_impl_getters.register(P.Fill)
def vm_impl_fill(self):
    """Generate vm_impl function for Fill.

    The returned callable builds an array of shape ``dims`` filled with
    ``x``. The dtype is int32 when ``x`` is an ``int`` (including ``bool``)
    and float32 otherwise.
    """

    def vm_impl(dims, x):
        fill_dtype = np.int32 if isinstance(x, int) else np.float32
        return Tensor(np.full(dims, x, fill_dtype))

    return vm_impl


@vm_impl_getters.register(P.Eye)
def vm_impl_eye(self):
    """Generate vm_impl function for Eye.

    The returned callable builds an ``n`` x ``m`` ``np.eye`` matrix in the
    numpy type corresponding to MindSpore dtype ``t``.
    """

    def vm_impl(n, m, t):
        identity = np.eye(n, m, dtype=mstype.dtype_to_nptype(t))
        return Tensor(identity)

    return vm_impl


@vm_impl_getters.register(P.InvertPermutation)
def vm_impl_invert_permutation(self):
    """Generate vm_impl function for InvertPermutation.

    The returned callable passes ``x`` straight to ``vm.invert_permutation``
    and returns its result without any Tensor conversion.
    """

    def vm_impl(x):
        return vm.invert_permutation(x)

    return vm_impl


@vm_impl_getters.register(P.Argmax)
def vm_impl_argmax(self):
    """Generate vm_impl function for Argmax.

    The returned callable takes ``np.argmax`` of ``x`` along the primitive's
    ``axis`` attribute and returns the flattened indices as a Tensor.
    """

    def vm_impl(x):
        indices = np.argmax(x.asnumpy(), axis=self.axis)
        # ravel() flattens the indices to 1-D (a 0-d result becomes shape (1,)).
        flat = indices.ravel()
        return Tensor(flat)

    return vm_impl

J
jinyaohui 已提交
177

Z
zhunaipan 已提交
178 179 180
@vm_impl_getters.register(P.Tile)
def vm_impl_tile(self):
    """Generate vm_impl function for Tile.

    The returned callable converts both ``x`` and ``multiples`` to numpy,
    tiles ``x`` with ``vm.Tile``, and returns the result as a Tensor.
    """

    def vm_impl(x, multiples):
        x_np, reps = x.asnumpy(), multiples.asnumpy()
        return Tensor(vm.Tile(x_np, reps))

    return vm_impl


@vm_impl_getters.register(P.ReduceAll)
def vm_impl_all(self):
    """Generate vm_impl function for All.

    The returned callable reduces ``x`` over ``axis`` with ``vm.all`` and
    returns the result as a Tensor.
    """

    def vm_impl(x, axis):
        reduced = vm.all(x.asnumpy(), axis)
        return Tensor(reduced)

    return vm_impl


@vm_impl_getters.register(P.Concat)
def vm_impl_concatV2(self):
    """Generate vm_impl function for Concat.

    The returned callable concatenates via ``vm.Concat`` along the
    primitive's ``axis`` attribute and returns the result as a Tensor.
    """

    def vm_impl(x):
        # NOTE(review): x.asnumpy() requires `x` to be a single Tensor here;
        # confirm how the VM packs Concat's (tuple) input before calling this.
        joined = vm.Concat(x.asnumpy(), self.axis)
        return Tensor(joined)

    return vm_impl


@vm_impl_getters.register(P.Slice)
def vm_impl_slice(self):
    """Generate vm_impl function for Slice.

    The returned callable converts ``x``, ``begin`` and ``size`` to numpy
    (in that order), slices with ``vm.Slice``, and returns a Tensor.
    """

    def vm_impl(x, begin, size):
        arrays = [arg.asnumpy() for arg in (x, begin, size)]
        return Tensor(vm.Slice(*arrays))

    return vm_impl


229
@vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self):
    """Generate vm_impl function for ConcatOffset.

    The returned callable forwards its input to ``vm.ConcatOffset`` and
    returns that result unchanged.
    """

    def vm_impl(x):
        # vm.ConcatOffset yields a tuple, so it is returned without Tensor wrapping.
        return vm.ConcatOffset(x)

    return vm_impl


@vm_impl_getters.register(P.ReduceSum)
def vm_impl_sum(self):
    """Generate vm_impl function for Sum.

    The returned callable reduces ``x`` over ``axis`` with ``vm.sum`` and
    returns the result as a Tensor.
    """

    def vm_impl(x, axis):
        total = vm.sum(x.asnumpy(), axis)
        # np.array normalizes the reduction result (presumably it may be a
        # numpy scalar for full reductions) into an ndarray before wrapping.
        return Tensor(np.array(total))

    return vm_impl


@vm_impl_getters.register(P.Select)
def vm_impl_select(self):
    """Generate vm_impl function for Select.

    The returned callable converts its three Tensor inputs to numpy,
    combines them with ``vm.select``, and returns the result as a Tensor.
    """

    def vm_impl(cond, x, y):
        """
        Args:
            cond: A boolean `Tensor` used as the selection condition.
            x: A `Tensor` whose shape may match that of `cond`.
            y: A `Tensor` with the same shape and type as `x`.
        """
        cond_np, x_np, y_np = (t.asnumpy() for t in (cond, x, y))
        return Tensor(vm.select(cond_np, x_np, y_np))

    return vm_impl


@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square.

    The returned callable multiplies the numpy view of ``x`` element-wise by
    itself and returns the product as a Tensor.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        squared = arr * arr
        return Tensor(squared)

    return vm_impl